diff --git a/core/translate/logical.rs b/core/translate/logical.rs new file mode 100644 index 000000000..df8bdd13a --- /dev/null +++ b/core/translate/logical.rs @@ -0,0 +1,3076 @@ +//! Logical plan representation for SQL queries +//! +//! This module provides a platform-independent intermediate representation +//! for SQL queries. The logical plan is a DAG (Directed Acyclic Graph) that +//! supports CTEs and can be used for query optimization before being compiled +//! to an execution plan (e.g., DBSP circuits). +//! +//! The main entry point is `LogicalPlanBuilder` which constructs logical plans +//! from SQL AST nodes. +use crate::function::AggFunc; +use crate::schema::{Schema, Type}; +use crate::types::Value; +use crate::{LimboError, Result}; +use std::collections::HashMap; +use std::fmt::{self, Display, Formatter}; +use std::sync::Arc; +use turso_parser::ast; + +/// Result type for preprocessing aggregate expressions +type PreprocessAggregateResult = ( + bool, // needs_pre_projection + Vec, // pre_projection_exprs + Vec<(String, Type)>, // pre_projection_schema + Vec, // modified_aggr_exprs +); + +/// Schema information for logical plan nodes +#[derive(Debug, Clone, PartialEq)] +pub struct LogicalSchema { + /// Column names and types + pub columns: Vec<(String, Type)>, +} +/// A reference to a schema that can be shared between nodes +pub type SchemaRef = Arc; + +impl LogicalSchema { + pub fn new(columns: Vec<(String, Type)>) -> Self { + Self { columns } + } + + pub fn empty() -> Self { + Self { + columns: Vec::new(), + } + } + + pub fn column_count(&self) -> usize { + self.columns.len() + } + + pub fn find_column(&self, name: &str) -> Option<(usize, &Type)> { + self.columns + .iter() + .position(|(n, _)| n == name) + .map(|idx| (idx, &self.columns[idx].1)) + } +} + +/// Logical representation of a SQL query plan +#[derive(Debug, Clone, PartialEq)] +pub enum LogicalPlan { + /// Projection - SELECT expressions + Projection(Projection), + /// Filter - WHERE/HAVING clause + Filter(Filter), + /// Aggregate - GROUP BY with aggregate functions + Aggregate(Aggregate), + // TODO: Join - combining two relations (not yet implemented) + // Join(Join), + /// Sort - ORDER BY clause + Sort(Sort), + /// Limit - LIMIT/OFFSET clause + Limit(Limit), + /// Table scan - reading from a base table + TableScan(TableScan), + /// Union - UNION/UNION ALL/INTERSECT/EXCEPT + Union(Union), + /// Distinct - remove duplicates + Distinct(Distinct), + /// Empty relation - no rows + EmptyRelation(EmptyRelation), + /// Values - literal rows (VALUES clause) + Values(Values), + /// CTE support - WITH clause + WithCTE(WithCTE), + /// Reference to a CTE + CTERef(CTERef), +} + +impl LogicalPlan { + /// Get the schema of this plan node + pub fn schema(&self) -> &SchemaRef { + match self { + LogicalPlan::Projection(p) => &p.schema, + LogicalPlan::Filter(f) => f.input.schema(), + LogicalPlan::Aggregate(a) => &a.schema, + // LogicalPlan::Join(j) => &j.schema, + LogicalPlan::Sort(s) => s.input.schema(), + LogicalPlan::Limit(l) => l.input.schema(), + LogicalPlan::TableScan(t) => &t.schema, + LogicalPlan::Union(u) => &u.schema, + LogicalPlan::Distinct(d) => d.input.schema(), + LogicalPlan::EmptyRelation(e) => &e.schema, + LogicalPlan::Values(v) => &v.schema, + LogicalPlan::WithCTE(w) => w.body.schema(), + LogicalPlan::CTERef(c) => &c.schema, + } + } +} + +/// Projection operator - SELECT expressions +#[derive(Debug, Clone, PartialEq)] +pub struct Projection { + pub input: Arc, + pub exprs: Vec, + pub schema: SchemaRef, +} + +/// Filter operator - WHERE/HAVING predicates +#[derive(Debug, Clone, PartialEq)] +pub struct Filter { + pub input: Arc, + pub predicate: LogicalExpr, +} + +/// Aggregate operator - GROUP BY with aggregations +#[derive(Debug, Clone, PartialEq)] +pub struct Aggregate { + pub input: Arc, + pub group_expr: Vec, + pub aggr_expr: Vec, + pub schema: SchemaRef, +} + +// TODO: Join operator (not yet implemented) +// #[derive(Debug, Clone, PartialEq)] +// pub struct Join { +// pub left: Arc, +// pub right: Arc, +// pub join_type: JoinType, +// pub on: Vec<(LogicalExpr, LogicalExpr)>, // Equijoin conditions +// pub filter: Option, // Additional filter conditions +// pub schema: SchemaRef, +// } + +// TODO: Types of joins (not yet implemented) +// #[derive(Debug, Clone, Copy, PartialEq)] +// pub enum JoinType { +// Inner, +// Left, +// Right, +// Full, +// Cross, +// } + +/// Sort operator - ORDER BY +#[derive(Debug, Clone, PartialEq)] +pub struct Sort { + pub input: Arc, + pub exprs: Vec, +} + +/// Sort expression with direction +#[derive(Debug, Clone, PartialEq)] +pub struct SortExpr { + pub expr: LogicalExpr, + pub asc: bool, + pub nulls_first: bool, +} + +/// Limit operator - LIMIT/OFFSET +#[derive(Debug, Clone, PartialEq)] +pub struct Limit { + pub input: Arc, + pub skip: Option, + pub fetch: Option, +} + +/// Table scan operator +#[derive(Debug, Clone, PartialEq)] +pub struct TableScan { + pub table_name: String, + pub schema: SchemaRef, + pub projection: Option>, // Column indices to project +} + +/// Union operator +#[derive(Debug, Clone, PartialEq)] +pub struct Union { + pub inputs: Vec>, + pub all: bool, // true for UNION ALL, false for UNION + pub schema: SchemaRef, +} + +/// Distinct operator +#[derive(Debug, Clone, PartialEq)] +pub struct Distinct { + pub input: Arc, +} + +/// Empty relation - produces no rows +#[derive(Debug, Clone, PartialEq)] +pub struct EmptyRelation { + pub produce_one_row: bool, + pub schema: SchemaRef, +} + +/// Values operator - literal rows +#[derive(Debug, Clone, PartialEq)] +pub struct Values { + pub rows: Vec>, + pub schema: SchemaRef, +} + +/// WITH clause - CTEs +#[derive(Debug, Clone, PartialEq)] +pub struct WithCTE { + pub ctes: HashMap>, + pub body: Arc, +} + +/// Reference to a CTE +#[derive(Debug, Clone, PartialEq)] +pub struct CTERef { + pub name: String, + pub schema: SchemaRef, +} + +/// Logical expression representation +#[derive(Debug, Clone, PartialEq)] +pub enum LogicalExpr { + /// Column reference + Column(Column), + /// Literal value + Literal(Value), + /// Binary expression + BinaryExpr { + left: Box, + op: BinaryOperator, + right: Box, + }, + /// Unary expression + UnaryExpr { + op: UnaryOperator, + expr: Box, + }, + /// Aggregate function + AggregateFunction { + fun: AggregateFunction, + args: Vec, + distinct: bool, + }, + /// Scalar function call + ScalarFunction { fun: String, args: Vec }, + /// CASE expression + Case { + expr: Option>, + when_then: Vec<(LogicalExpr, LogicalExpr)>, + else_expr: Option>, + }, + /// IN list + InList { + expr: Box, + list: Vec, + negated: bool, + }, + /// IN subquery + InSubquery { + expr: Box, + subquery: Arc, + negated: bool, + }, + /// EXISTS subquery + Exists { + subquery: Arc, + negated: bool, + }, + /// Scalar subquery + ScalarSubquery(Arc), + /// Alias for an expression + Alias { + expr: Box, + alias: String, + }, + /// IS NULL / IS NOT NULL + IsNull { + expr: Box, + negated: bool, + }, + /// BETWEEN + Between { + expr: Box, + low: Box, + high: Box, + negated: bool, + }, + /// LIKE pattern matching + Like { + expr: Box, + pattern: Box, + escape: Option, + negated: bool, + }, +} + +/// Column reference +#[derive(Debug, Clone, PartialEq)] +pub struct Column { + pub name: String, + pub table: Option, +} + +impl Column { + pub fn new(name: impl Into) -> Self { + Self { + name: name.into(), + table: None, + } + } + + pub fn with_table(name: impl Into, table: impl Into) -> Self { + Self { + name: name.into(), + table: Some(table.into()), + } + } +} + +impl Display for Column { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + match &self.table { + Some(t) => write!(f, "{}.{}", t, self.name), + None => write!(f, "{}", self.name), + } + } +} + +/// Type alias for binary operators +pub type BinaryOperator = ast::Operator; + +/// Type alias for unary operators +pub type UnaryOperator = ast::UnaryOperator; + +/// Type alias for aggregate functions +pub type AggregateFunction = AggFunc; + +/// Compiler from AST to LogicalPlan +pub struct LogicalPlanBuilder<'a> { + schema: &'a Schema, + ctes: HashMap>, +} + +impl<'a> LogicalPlanBuilder<'a> { + pub fn new(schema: &'a Schema) -> Self { + Self { + schema, + ctes: HashMap::new(), + } + } + + /// Main entry point: compile a statement to a logical plan + pub fn build_statement(&mut self, stmt: &ast::Stmt) -> Result { + match stmt { + ast::Stmt::Select(select) => self.build_select(select), + _ => Err(LimboError::ParseError( + "Only SELECT statements are currently supported in logical plans".to_string(), + )), + } + } + + // Convert Name to String + fn name_to_string(name: &ast::Name) -> String { + match name { + ast::Name::Ident(s) | ast::Name::Quoted(s) => s.clone(), + } + } + + // Build a SELECT statement + // Build a logical plan from a SELECT statement + fn build_select(&mut self, select: &ast::Select) -> Result { + // Handle WITH clause if present + if let Some(with) = &select.with { + return self.build_with_cte(with, select); + } + + // Build the main query body + let order_by = &select.order_by; + let limit = &select.limit; + self.build_select_body(&select.body, order_by, limit) + } + + // Build WITH CTE + fn build_with_cte(&mut self, with: &ast::With, select: &ast::Select) -> Result { + let mut cte_plans = HashMap::new(); + + // Build each CTE + for cte in &with.ctes { + let cte_plan = self.build_select(&cte.select)?; + let cte_name = Self::name_to_string(&cte.tbl_name); + cte_plans.insert(cte_name.clone(), Arc::new(cte_plan)); + self.ctes + .insert(cte_name.clone(), cte_plans[&cte_name].clone()); + } + + // Build the main body with CTEs available + let order_by = &select.order_by; + let limit = &select.limit; + let body = self.build_select_body(&select.body, order_by, limit)?; + + // Clear CTEs from builder context + for cte in &with.ctes { + self.ctes.remove(&Self::name_to_string(&cte.tbl_name)); + } + + Ok(LogicalPlan::WithCTE(WithCTE { + ctes: cte_plans, + body: Arc::new(body), + })) + } + + // Build SELECT body + fn build_select_body( + &mut self, + body: &ast::SelectBody, + order_by: &[ast::SortedColumn], + limit: &Option, + ) -> Result { + let mut plan = self.build_one_select(&body.select)?; + + // Handle compound operators (UNION, INTERSECT, EXCEPT) + if !body.compounds.is_empty() { + for compound in &body.compounds { + let right = self.build_one_select(&compound.select)?; + plan = Self::build_compound(plan, right, &compound.operator)?; + } + } + + // Apply ORDER BY + if !order_by.is_empty() { + plan = self.build_sort(plan, order_by)?; + } + + // Apply LIMIT + if let Some(limit) = limit { + plan = Self::build_limit(plan, limit)?; + } + + Ok(plan) + } + + // Build a single SELECT (without compounds) + fn build_one_select(&mut self, select: &ast::OneSelect) -> Result { + match select { + ast::OneSelect::Select { + distinctness, + columns, + from, + where_clause, + group_by, + window_clause: _, + } => { + // Start with FROM clause + let mut plan = if let Some(from) = from { + self.build_from(from)? + } else { + // No FROM clause - single row + LogicalPlan::EmptyRelation(EmptyRelation { + produce_one_row: true, + schema: Arc::new(LogicalSchema::empty()), + }) + }; + + // Apply WHERE + if let Some(where_expr) = where_clause { + let predicate = self.build_expr(where_expr, plan.schema())?; + plan = LogicalPlan::Filter(Filter { + input: Arc::new(plan), + predicate, + }); + } + + // Apply GROUP BY and aggregations + if let Some(group_by) = group_by { + plan = self.build_aggregate(plan, group_by, columns)?; + } else if Self::has_aggregates(columns) { + // Aggregation without GROUP BY + plan = self.build_aggregate_no_group(plan, columns)?; + } else { + // Regular projection + plan = self.build_projection(plan, columns)?; + } + + // Apply HAVING (part of GROUP BY) + if let Some(ref group_by) = group_by { + if let Some(ref having_expr) = group_by.having { + let predicate = self.build_expr(having_expr, plan.schema())?; + plan = LogicalPlan::Filter(Filter { + input: Arc::new(plan), + predicate, + }); + } + } + + // Apply DISTINCT + if distinctness.is_some() { + plan = LogicalPlan::Distinct(Distinct { + input: Arc::new(plan), + }); + } + + Ok(plan) + } + ast::OneSelect::Values(values) => self.build_values(values), + } + } + + // Build FROM clause + fn build_from(&mut self, from: &ast::FromClause) -> Result { + let mut plan = { self.build_select_table(&from.select)? }; + + // Handle JOINs + if !from.joins.is_empty() { + for join in &from.joins { + let right = self.build_select_table(&join.table)?; + plan = self.build_join(plan, right, &join.operator, &join.constraint)?; + } + } + + Ok(plan) + } + + // Build a table reference + fn build_select_table(&mut self, table: &ast::SelectTable) -> Result { + match table { + ast::SelectTable::Table(name, _alias, _indexed) => { + let table_name = Self::name_to_string(&name.name); + // Check if it's a CTE reference + if let Some(cte_plan) = self.ctes.get(&table_name) { + return Ok(LogicalPlan::CTERef(CTERef { + name: table_name.clone(), + schema: cte_plan.schema().clone(), + })); + } + + // Regular table scan + let table_schema = self.get_table_schema(&table_name)?; + Ok(LogicalPlan::TableScan(TableScan { + table_name, + schema: table_schema, + projection: None, + })) + } + ast::SelectTable::Select(subquery, _alias) => self.build_select(subquery), + ast::SelectTable::TableCall(_, _, _) => Err(LimboError::ParseError( + "Table-valued functions are not supported in logical plans".to_string(), + )), + ast::SelectTable::Sub(_, _) => Err(LimboError::ParseError( + "Subquery in FROM clause not yet supported".to_string(), + )), + } + } + + // Build JOIN + fn build_join( + &mut self, + _left: LogicalPlan, + _right: LogicalPlan, + _op: &ast::JoinOperator, + _constraint: &Option, + ) -> Result { + Err(LimboError::ParseError( + "JOINs are not yet supported in logical plans".to_string(), + )) + } + + // Build projection + fn build_projection( + &mut self, + input: LogicalPlan, + columns: &[ast::ResultColumn], + ) -> Result { + let input_schema = input.schema(); + let mut proj_exprs = Vec::new(); + let mut schema_columns = Vec::new(); + + for col in columns { + match col { + ast::ResultColumn::Expr(expr, alias) => { + let logical_expr = self.build_expr(expr, input_schema)?; + let col_name = match alias { + Some(as_alias) => match as_alias { + ast::As::As(name) | ast::As::Elided(name) => Self::name_to_string(name), + }, + None => Self::expr_to_column_name(expr), + }; + let col_type = Self::infer_expr_type(&logical_expr, input_schema)?; + + schema_columns.push((col_name.clone(), col_type)); + + if let Some(as_alias) = alias { + let alias_name = match as_alias { + ast::As::As(name) | ast::As::Elided(name) => Self::name_to_string(name), + }; + proj_exprs.push(LogicalExpr::Alias { + expr: Box::new(logical_expr), + alias: alias_name, + }); + } else { + proj_exprs.push(logical_expr); + } + } + ast::ResultColumn::Star => { + // Expand * to all columns + for (name, typ) in &input_schema.columns { + proj_exprs.push(LogicalExpr::Column(Column::new(name.clone()))); + schema_columns.push((name.clone(), *typ)); + } + } + ast::ResultColumn::TableStar(table) => { + // Expand table.* to all columns from that table + let table_name = Self::name_to_string(table); + for (name, typ) in &input_schema.columns { + // Simple check - would need proper table tracking in real implementation + proj_exprs.push(LogicalExpr::Column(Column::with_table( + name.clone(), + table_name.clone(), + ))); + schema_columns.push((name.clone(), *typ)); + } + } + } + } + + Ok(LogicalPlan::Projection(Projection { + input: Arc::new(input), + exprs: proj_exprs, + schema: Arc::new(LogicalSchema::new(schema_columns)), + })) + } + + // Helper function to preprocess aggregate expressions that contain complex arguments + // Returns: (needs_pre_projection, pre_projection_exprs, pre_projection_schema, modified_aggr_exprs) + // + // This will be used in expressions like select sum(hex(a + 2)) from tbl => hex(a + 2) is a + // pre-projection. + // + // Another alternative is to always generate a projection together with an aggregation, and + // just have "a" be the identity projection if we don't have a complex case. But that's quite + // wasteful. + fn preprocess_aggregate_expressions( + aggr_exprs: &[LogicalExpr], + group_exprs: &[LogicalExpr], + input_schema: &SchemaRef, + ) -> Result { + let mut needs_pre_projection = false; + let mut pre_projection_exprs = Vec::new(); + let mut pre_projection_schema = Vec::new(); + let mut modified_aggr_exprs = Vec::new(); + let mut projected_col_counter = 0; + + // First, add all group by expressions to the pre-projection + for expr in group_exprs { + if let LogicalExpr::Column(col) = expr { + pre_projection_exprs.push(expr.clone()); + let col_type = Self::infer_expr_type(expr, input_schema)?; + pre_projection_schema.push((col.name.clone(), col_type)); + } else { + // Complex group by expression - project it + needs_pre_projection = true; + let proj_col_name = format!("__group_proj_{projected_col_counter}"); + projected_col_counter += 1; + pre_projection_exprs.push(expr.clone()); + let col_type = Self::infer_expr_type(expr, input_schema)?; + pre_projection_schema.push((proj_col_name.clone(), col_type)); + } + } + + // Check each aggregate expression + for agg_expr in aggr_exprs { + if let LogicalExpr::AggregateFunction { + fun, + args, + distinct, + } = agg_expr + { + let mut modified_args = Vec::new(); + for arg in args { + // Check if the argument is a simple column reference or a complex expression + match arg { + LogicalExpr::Column(_) => { + // Simple column - just use it + modified_args.push(arg.clone()); + // Make sure the column is in the pre-projection + if !pre_projection_exprs.iter().any(|e| e == arg) { + pre_projection_exprs.push(arg.clone()); + let col_type = Self::infer_expr_type(arg, input_schema)?; + if let LogicalExpr::Column(col) = arg { + pre_projection_schema.push((col.name.clone(), col_type)); + } + } + } + _ => { + // Complex expression - we need to project it first + needs_pre_projection = true; + let proj_col_name = format!("__agg_arg_proj_{projected_col_counter}"); + projected_col_counter += 1; + + // Add the expression to the pre-projection + pre_projection_exprs.push(arg.clone()); + let col_type = Self::infer_expr_type(arg, input_schema)?; + pre_projection_schema.push((proj_col_name.clone(), col_type)); + + // In the aggregate, reference the projected column + modified_args.push(LogicalExpr::Column(Column::new(proj_col_name))); + } + } + } + + // Create the modified aggregate expression + modified_aggr_exprs.push(LogicalExpr::AggregateFunction { + fun: fun.clone(), + args: modified_args, + distinct: *distinct, + }); + } else { + modified_aggr_exprs.push(agg_expr.clone()); + } + } + + Ok(( + needs_pre_projection, + pre_projection_exprs, + pre_projection_schema, + modified_aggr_exprs, + )) + } + + // Build aggregate with GROUP BY + fn build_aggregate( + &mut self, + input: LogicalPlan, + group_by: &ast::GroupBy, + columns: &[ast::ResultColumn], + ) -> Result { + let input_schema = input.schema(); + + // Build grouping expressions + let mut group_exprs = Vec::new(); + for expr in &group_by.exprs { + group_exprs.push(self.build_expr(expr, input_schema)?); + } + + // Use the unified aggregate builder + self.build_aggregate_internal(input, group_exprs, columns) + } + + // Build aggregate without GROUP BY + fn build_aggregate_no_group( + &mut self, + input: LogicalPlan, + columns: &[ast::ResultColumn], + ) -> Result { + // Use the unified aggregate builder with empty group expressions + self.build_aggregate_internal(input, vec![], columns) + } + + // Unified internal aggregate builder that handles both GROUP BY and non-GROUP BY cases + fn build_aggregate_internal( + &mut self, + input: LogicalPlan, + group_exprs: Vec, + columns: &[ast::ResultColumn], + ) -> Result { + let input_schema = input.schema(); + let has_group_by = !group_exprs.is_empty(); + + // Build aggregate expressions and projection expressions + let mut aggr_exprs = Vec::new(); + let mut projection_exprs = Vec::new(); + let mut aggregate_schema_columns = Vec::new(); + + // First, add GROUP BY columns to the aggregate output schema + // These are always part of the aggregate operator's output + for group_expr in &group_exprs { + let col_name = match group_expr { + LogicalExpr::Column(col) => col.name.clone(), + _ => { + // For complex GROUP BY expressions, generate a name + format!("__group_{}", aggregate_schema_columns.len()) + } + }; + let col_type = Self::infer_expr_type(group_expr, input_schema)?; + aggregate_schema_columns.push((col_name, col_type)); + } + + // Track aggregates we've already seen to avoid duplicates + let mut aggregate_map: std::collections::HashMap = + std::collections::HashMap::new(); + + for col in columns { + match col { + ast::ResultColumn::Expr(expr, alias) => { + let logical_expr = self.build_expr(expr, input_schema)?; + + // Determine the column name for this expression + let col_name = match alias { + Some(as_alias) => match as_alias { + ast::As::As(name) | ast::As::Elided(name) => Self::name_to_string(name), + }, + None => Self::expr_to_column_name(expr), + }; + + // Check if the TOP-LEVEL expression is an aggregate + // We only care about immediate aggregates, not nested ones + if Self::is_aggregate_expr(&logical_expr) { + // Pure aggregate function - check if we've seen it before + let agg_key = format!("{logical_expr:?}"); + + let agg_col_name = if let Some(existing_name) = aggregate_map.get(&agg_key) + { + // Reuse existing aggregate + existing_name.clone() + } else { + // New aggregate - add it + let col_type = Self::infer_expr_type(&logical_expr, input_schema)?; + aggregate_schema_columns.push((col_name.clone(), col_type)); + aggr_exprs.push(logical_expr); + aggregate_map.insert(agg_key, col_name.clone()); + col_name.clone() + }; + + // In the projection, reference this aggregate by name + projection_exprs.push(LogicalExpr::Column(Column { + name: agg_col_name, + table: None, + })); + } else if Self::contains_aggregate(&logical_expr) { + // This is an expression that contains an aggregate somewhere + // (e.g., sum(a + 2) * 2) + // We need to extract aggregates and replace them with column references + let (processed_expr, extracted_aggs) = + Self::extract_and_replace_aggregates_with_dedup( + logical_expr, + &mut aggregate_map, + )?; + + // Add only new aggregates + for (agg_expr, agg_name) in extracted_aggs { + let agg_type = Self::infer_expr_type(&agg_expr, input_schema)?; + aggregate_schema_columns.push((agg_name, agg_type)); + aggr_exprs.push(agg_expr); + } + + // Add the processed expression (with column refs) to projection + projection_exprs.push(processed_expr); + } else { + // Non-aggregate expression - validation depends on GROUP BY presence + if has_group_by { + // With GROUP BY: only allow constants and grouped columns + // TODO: SQLite actually allows any column here and returns the value from + // the first row encountered in each group. We should support this in the + // future for full SQLite compatibility, but for now we're stricter to + // simplify the DBSP compilation. + if !Self::is_constant_expr(&logical_expr) + && !Self::is_valid_in_group_by(&logical_expr, &group_exprs) + { + return Err(LimboError::ParseError(format!( + "Column '{col_name}' must appear in the GROUP BY clause or be used in an aggregate function" + ))); + } + } else { + // Without GROUP BY: only allow constant expressions + // TODO: SQLite allows any column here and returns a value from an + // arbitrary row. We should support this for full compatibility, + // but for now we're stricter to simplify DBSP compilation. + if !Self::is_constant_expr(&logical_expr) { + return Err(LimboError::ParseError(format!( + "Column '{col_name}' must be used in an aggregate function when using aggregates without GROUP BY" + ))); + } + } + projection_exprs.push(logical_expr); + } + } + _ => { + let error_msg = if has_group_by { + "* not supported with GROUP BY".to_string() + } else { + "* not supported with aggregate functions".to_string() + }; + return Err(LimboError::ParseError(error_msg)); + } + } + } + + // Check if any aggregate functions have complex expressions as arguments + // If so, we need to insert a projection before the aggregate + let ( + needs_pre_projection, + pre_projection_exprs, + pre_projection_schema, + modified_aggr_exprs, + ) = Self::preprocess_aggregate_expressions(&aggr_exprs, &group_exprs, input_schema)?; + + // Build the final schema for the projection + let mut projection_schema_columns = Vec::new(); + for (i, expr) in projection_exprs.iter().enumerate() { + let col_name = if i < columns.len() { + match &columns[i] { + ast::ResultColumn::Expr(e, alias) => match alias { + Some(as_alias) => match as_alias { + ast::As::As(name) | ast::As::Elided(name) => Self::name_to_string(name), + }, + None => Self::expr_to_column_name(e), + }, + _ => format!("col_{i}"), + } + } else { + format!("col_{i}") + }; + + // For type inference, we need the aggregate schema for column references + let aggregate_schema = LogicalSchema::new(aggregate_schema_columns.clone()); + let col_type = Self::infer_expr_type(expr, &Arc::new(aggregate_schema))?; + projection_schema_columns.push((col_name, col_type)); + } + + // Create the input plan (with pre-projection if needed) + let aggregate_input = if needs_pre_projection { + Arc::new(LogicalPlan::Projection(Projection { + input: Arc::new(input), + exprs: pre_projection_exprs, + schema: Arc::new(LogicalSchema::new(pre_projection_schema)), + })) + } else { + Arc::new(input) + }; + + // Use modified aggregate expressions if we inserted a pre-projection + let final_aggr_exprs = if needs_pre_projection { + modified_aggr_exprs + } else { + aggr_exprs + }; + + // Check if we need the outer projection + // We need a projection if: + // 1. Any expression is more complex than a simple column reference (e.g., abs(sum(id))) + // 2. We're selecting a different set of columns than what the aggregate outputs + // 3. Columns are renamed or reordered + let needs_outer_projection = { + // Check if any expression is more complex than a simple column reference + let has_complex_exprs = projection_exprs + .iter() + .any(|expr| !matches!(expr, LogicalExpr::Column(_))); + + if has_complex_exprs { + true + } else { + // All are simple columns - check if we're selecting exactly what the aggregate outputs + // The projection might be selecting a subset (e.g., only aggregates without group columns) + // or reordering columns, or using different names + + // For now, keep it simple: if schemas don't match exactly, we need projection + // This handles all cases: subset selection, reordering, renaming + projection_schema_columns != aggregate_schema_columns + } + }; + + // Create the aggregate node + let aggregate_plan = LogicalPlan::Aggregate(Aggregate { + input: aggregate_input, + group_expr: group_exprs, + aggr_expr: final_aggr_exprs, + schema: Arc::new(LogicalSchema::new(aggregate_schema_columns)), + }); + + if needs_outer_projection { + Ok(LogicalPlan::Projection(Projection { + input: Arc::new(aggregate_plan), + exprs: projection_exprs, + schema: Arc::new(LogicalSchema::new(projection_schema_columns)), + })) + } else { + // No projection needed - the aggregate output is exactly what we want + Ok(aggregate_plan) + } + } + + /// Build VALUES clause + #[allow(clippy::vec_box)] + fn build_values(&mut self, values: &[Vec>]) -> Result { + if values.is_empty() { + return Err(LimboError::ParseError("Empty VALUES clause".to_string())); + } + + let mut rows = Vec::new(); + let first_row_len = values[0].len(); + + // Infer schema from first row + let mut schema_columns = Vec::new(); + for (i, _) in values[0].iter().enumerate() { + schema_columns.push((format!("column{}", i + 1), Type::Text)); + } + + for row in values { + if row.len() != first_row_len { + return Err(LimboError::ParseError( + "All rows in VALUES must have the same number of columns".to_string(), + )); + } + + let mut logical_row = Vec::new(); + for expr in row { + // VALUES doesn't have input schema + let empty_schema = Arc::new(LogicalSchema::empty()); + logical_row.push(self.build_expr(expr, &empty_schema)?); + } + rows.push(logical_row); + } + + Ok(LogicalPlan::Values(Values { + rows, + schema: Arc::new(LogicalSchema::new(schema_columns)), + })) + } + + // Build SORT + fn build_sort( + &mut self, + input: LogicalPlan, + exprs: &[ast::SortedColumn], + ) -> Result { + let input_schema = input.schema(); + let mut sort_exprs = Vec::new(); + + for sorted_col in exprs { + let expr = self.build_expr(&sorted_col.expr, input_schema)?; + sort_exprs.push(SortExpr { + expr, + asc: sorted_col.order != Some(ast::SortOrder::Desc), + nulls_first: sorted_col.nulls == Some(ast::NullsOrder::First), + }); + } + + Ok(LogicalPlan::Sort(Sort { + input: Arc::new(input), + exprs: sort_exprs, + })) + } + + // Build LIMIT + fn build_limit(input: LogicalPlan, limit: &ast::Limit) -> Result { + let fetch = match limit.expr.as_ref() { + ast::Expr::Literal(ast::Literal::Numeric(s)) => s.parse::().ok(), + _ => { + return Err(LimboError::ParseError( + "LIMIT must be a literal integer".to_string(), + )) + } + }; + + let skip = if let Some(offset) = &limit.offset { + match offset.as_ref() { + ast::Expr::Literal(ast::Literal::Numeric(s)) => s.parse::().ok(), + _ => { + return Err(LimboError::ParseError( + "OFFSET must be a literal integer".to_string(), + )) + } + } + } else { + None + }; + + Ok(LogicalPlan::Limit(Limit { + input: Arc::new(input), + skip, + fetch, + })) + } + + // Build compound operator (UNION, INTERSECT, EXCEPT) + fn build_compound( + left: LogicalPlan, + right: LogicalPlan, + op: &ast::CompoundOperator, + ) -> Result { + // Check schema compatibility + if left.schema().column_count() != right.schema().column_count() { + return Err(LimboError::ParseError( + "UNION/INTERSECT/EXCEPT requires same number of columns".to_string(), + )); + } + + let all = matches!(op, ast::CompoundOperator::UnionAll); + + match op { + ast::CompoundOperator::Union | ast::CompoundOperator::UnionAll => { + let schema = left.schema().clone(); + Ok(LogicalPlan::Union(Union { + inputs: vec![Arc::new(left), Arc::new(right)], + all, + schema, + })) + } + _ => Err(LimboError::ParseError( + "INTERSECT and EXCEPT not yet supported in logical plans".to_string(), + )), + } + } + + // Build expression from AST + fn build_expr(&mut self, expr: &ast::Expr, _schema: &SchemaRef) -> Result { + match expr { + ast::Expr::Id(name) => Ok(LogicalExpr::Column(Column::new(Self::name_to_string(name)))), + + ast::Expr::DoublyQualified(db, table, col) => { + Ok(LogicalExpr::Column(Column::with_table( + Self::name_to_string(col), + format!( + "{}.{}", + Self::name_to_string(db), + Self::name_to_string(table) + ), + ))) + } + + ast::Expr::Qualified(table, col) => Ok(LogicalExpr::Column(Column::with_table( + Self::name_to_string(col), + Self::name_to_string(table), + ))), + + ast::Expr::Literal(lit) => Ok(LogicalExpr::Literal(Self::build_literal(lit)?)), + + ast::Expr::Binary(lhs, op, rhs) => { + // Special case: IS NULL and IS NOT NULL + if matches!(op, ast::Operator::Is | ast::Operator::IsNot) { + if let ast::Expr::Literal(ast::Literal::Null) = rhs.as_ref() { + let expr = Box::new(self.build_expr(lhs, _schema)?); + return Ok(LogicalExpr::IsNull { + expr, + negated: matches!(op, ast::Operator::IsNot), + }); + } + } + + let left = Box::new(self.build_expr(lhs, _schema)?); + let right = Box::new(self.build_expr(rhs, _schema)?); + Ok(LogicalExpr::BinaryExpr { + left, + op: *op, + right, + }) + } + + ast::Expr::Unary(op, expr) => { + let inner = Box::new(self.build_expr(expr, _schema)?); + Ok(LogicalExpr::UnaryExpr { + op: *op, + expr: inner, + }) + } + + ast::Expr::FunctionCall { + name, + distinctness, + args, + filter_over, + .. + } => { + // Check for window functions (OVER clause) + if filter_over.over_clause.is_some() { + return Err(LimboError::ParseError( + "Unsupported expression type: window functions are not yet supported" + .to_string(), + )); + } + + let func_name = Self::name_to_string(name); + let arg_count = args.len(); + // Check if it's an aggregate function (considering argument count for min/max) + if let Some(agg_fun) = Self::parse_aggregate_function(&func_name, arg_count) { + let distinct = distinctness.is_some(); + let arg_exprs = args + .iter() + .map(|e| self.build_expr(e, _schema)) + .collect::>>()?; + Ok(LogicalExpr::AggregateFunction { + fun: agg_fun, + args: arg_exprs, + distinct, + }) + } else { + // Regular scalar function + let arg_exprs = args + .iter() + .map(|e| self.build_expr(e, _schema)) + .collect::>>()?; + Ok(LogicalExpr::ScalarFunction { + fun: func_name, + args: arg_exprs, + }) + } + } + + ast::Expr::FunctionCallStar { name, .. } => { + // Handle COUNT(*) and similar + let func_name = Self::name_to_string(name); + // FunctionCallStar always has 0 args (it's the * form) + if let Some(agg_fun) = Self::parse_aggregate_function(&func_name, 0) { + Ok(LogicalExpr::AggregateFunction { + fun: agg_fun, + args: vec![], + distinct: false, + }) + } else { + Err(LimboError::ParseError(format!( + "Function {func_name}(*) is not supported" + ))) + } + } + + ast::Expr::Case { + base, + when_then_pairs, + else_expr, + } => { + let case_expr = if let Some(e) = base { + Some(Box::new(self.build_expr(e, _schema)?)) + } else { + None + }; + + let when_then_exprs = when_then_pairs + .iter() + .map(|(when, then)| { + Ok(( + self.build_expr(when, _schema)?, + self.build_expr(then, _schema)?, + )) + }) + .collect::>>()?; + + let else_result = if let Some(e) = else_expr { + Some(Box::new(self.build_expr(e, _schema)?)) + } else { + None + }; + + Ok(LogicalExpr::Case { + expr: case_expr, + when_then: when_then_exprs, + else_expr: else_result, + }) + } + + ast::Expr::InList { lhs, not, rhs } => { + let expr = Box::new(self.build_expr(lhs, _schema)?); + let list = rhs + .iter() + .map(|e| self.build_expr(e, _schema)) + .collect::>>()?; + Ok(LogicalExpr::InList { + expr, + list, + negated: *not, + }) + } + + ast::Expr::InSelect { lhs, not, rhs } => { + let expr = Box::new(self.build_expr(lhs, _schema)?); + let subquery = Arc::new(self.build_select(rhs)?); + Ok(LogicalExpr::InSubquery { + expr, + subquery, + negated: *not, + }) + } + + ast::Expr::Exists(select) => { + let subquery = Arc::new(self.build_select(select)?); + Ok(LogicalExpr::Exists { + subquery, + negated: false, + }) + } + + ast::Expr::Subquery(select) => { + let subquery = Arc::new(self.build_select(select)?); + Ok(LogicalExpr::ScalarSubquery(subquery)) + } + + ast::Expr::IsNull(lhs) => { + let expr = Box::new(self.build_expr(lhs, _schema)?); + Ok(LogicalExpr::IsNull { + expr, + negated: false, + }) + } + + ast::Expr::NotNull(lhs) => { + let expr = Box::new(self.build_expr(lhs, _schema)?); + Ok(LogicalExpr::IsNull { + expr, + negated: true, + }) + } + + ast::Expr::Between { + lhs, + not, + start, + end, + } => { + let expr = Box::new(self.build_expr(lhs, _schema)?); + let low = Box::new(self.build_expr(start, _schema)?); + let high = Box::new(self.build_expr(end, _schema)?); + Ok(LogicalExpr::Between { + expr, + low, + high, + negated: *not, + }) + } + + ast::Expr::Like { + lhs, + not, + op: _, + rhs, + escape, + } => { + let expr = Box::new(self.build_expr(lhs, _schema)?); + let pattern = Box::new(self.build_expr(rhs, _schema)?); + let escape_char = escape.as_ref().and_then(|e| { + if let ast::Expr::Literal(ast::Literal::String(s)) = e.as_ref() { + s.chars().next() + } else { + None + } + }); + Ok(LogicalExpr::Like { + expr, + pattern, + escape: escape_char, + negated: *not, + }) + } + + ast::Expr::Parenthesized(exprs) => { + // the assumption is that there is at least one parenthesis here. + // If this is not true, then I don't understand this code and can't be trusted. + assert!(!exprs.is_empty()); + // Multiple expressions in parentheses is unusual but handle it + // by building the first one (SQLite behavior) + self.build_expr(&exprs[0], _schema) + } + + _ => Err(LimboError::ParseError(format!( + "Unsupported expression type in logical plan: {expr:?}" + ))), + } + } + + /// Build literal value + fn build_literal(lit: &ast::Literal) -> Result { + match lit { + ast::Literal::Null => Ok(Value::Null), + ast::Literal::Keyword(k) if k.eq_ignore_ascii_case("true") => Ok(Value::Integer(1)), // SQLite uses int for bool + ast::Literal::Keyword(k) if k.eq_ignore_ascii_case("false") => Ok(Value::Integer(0)), // SQLite uses int for bool + ast::Literal::Keyword(k) => Ok(Value::Text(k.clone().into())), + ast::Literal::Numeric(s) => { + if let Ok(i) = s.parse::() { + Ok(Value::Integer(i)) + } else if let Ok(f) = s.parse::() { + Ok(Value::Float(f)) + } else { + Ok(Value::Text(s.clone().into())) + } + } + ast::Literal::String(s) => { + // Strip surrounding quotes from the SQL literal + // The parser includes quotes in the string value + let unquoted = if s.starts_with('\'') && s.ends_with('\'') && s.len() > 1 { + &s[1..s.len() - 1] + } else { + s.as_str() + }; + Ok(Value::Text(unquoted.to_string().into())) + } + ast::Literal::Blob(b) => Ok(Value::Blob(b.clone().into())), + ast::Literal::CurrentDate + | ast::Literal::CurrentTime + | ast::Literal::CurrentTimestamp => Err(LimboError::ParseError( + "Temporal literals not yet supported".to_string(), + )), + } + } + + /// Parse aggregate function name (considering argument count for min/max) + fn parse_aggregate_function(name: &str, arg_count: usize) -> Option { + match name.to_uppercase().as_str() { + "COUNT" => Some(AggFunc::Count), + "SUM" => Some(AggFunc::Sum), + "AVG" => Some(AggFunc::Avg), + // MIN and MAX are only aggregates with 1 argument + // With 2+ arguments, they're scalar functions + "MIN" if arg_count == 1 => Some(AggFunc::Min), + "MAX" if arg_count == 1 => Some(AggFunc::Max), + "GROUP_CONCAT" => Some(AggFunc::GroupConcat), + "STRING_AGG" => Some(AggFunc::StringAgg), + "TOTAL" => Some(AggFunc::Total), + _ => None, + } + } + + // Check if expression contains aggregates + fn has_aggregates(columns: &[ast::ResultColumn]) -> bool { + for col in columns { + if let ast::ResultColumn::Expr(expr, _) = col { + if Self::expr_has_aggregate(expr) { + return true; + } + } + } + false + } + + // Check if AST expression contains aggregates + fn expr_has_aggregate(expr: &ast::Expr) -> bool { + match expr { + ast::Expr::FunctionCall { name, args, .. } => { + // Check if the function itself is an aggregate (considering arg count for min/max) + let arg_count = args.len(); + if Self::parse_aggregate_function(&Self::name_to_string(name), arg_count).is_some() + { + return true; + } + // Also check if any arguments contain aggregates (for nested functions like HEX(SUM(...))) + args.iter().any(|arg| Self::expr_has_aggregate(arg)) + } + ast::Expr::FunctionCallStar { name, .. } => { + // FunctionCallStar always has 0 args + Self::parse_aggregate_function(&Self::name_to_string(name), 0).is_some() + } + ast::Expr::Binary(lhs, _, rhs) => { + Self::expr_has_aggregate(lhs) || Self::expr_has_aggregate(rhs) + } + ast::Expr::Unary(_, e) => Self::expr_has_aggregate(e), + ast::Expr::Case { + when_then_pairs, + else_expr, + .. + } => { + when_then_pairs + .iter() + .any(|(w, t)| Self::expr_has_aggregate(w) || Self::expr_has_aggregate(t)) + || else_expr + .as_ref() + .is_some_and(|e| Self::expr_has_aggregate(e)) + } + ast::Expr::Parenthesized(exprs) => { + // Check if any parenthesized expression contains an aggregate + exprs.iter().any(|e| Self::expr_has_aggregate(e)) + } + _ => false, + } + } + + // Check if logical expression is an aggregate + fn is_aggregate_expr(expr: &LogicalExpr) -> bool { + match expr { + LogicalExpr::AggregateFunction { .. } => true, + LogicalExpr::Alias { expr, .. } => Self::is_aggregate_expr(expr), + _ => false, + } + } + + // Check if logical expression contains an aggregate anywhere + fn contains_aggregate(expr: &LogicalExpr) -> bool { + match expr { + LogicalExpr::AggregateFunction { .. } => true, + LogicalExpr::Alias { expr, .. } => Self::contains_aggregate(expr), + LogicalExpr::BinaryExpr { left, right, .. } => { + Self::contains_aggregate(left) || Self::contains_aggregate(right) + } + LogicalExpr::UnaryExpr { expr, .. } => Self::contains_aggregate(expr), + LogicalExpr::ScalarFunction { args, .. } => args.iter().any(Self::contains_aggregate), + LogicalExpr::Case { + when_then, + else_expr, + .. + } => { + when_then + .iter() + .any(|(w, t)| Self::contains_aggregate(w) || Self::contains_aggregate(t)) + || else_expr + .as_ref() + .is_some_and(|e| Self::contains_aggregate(e)) + } + _ => false, + } + } + + // Check if an expression is a constant (contains only literals) + fn is_constant_expr(expr: &LogicalExpr) -> bool { + match expr { + LogicalExpr::Literal(_) => true, + LogicalExpr::BinaryExpr { left, right, .. } => { + Self::is_constant_expr(left) && Self::is_constant_expr(right) + } + LogicalExpr::UnaryExpr { expr, .. } => Self::is_constant_expr(expr), + LogicalExpr::ScalarFunction { args, .. } => args.iter().all(Self::is_constant_expr), + LogicalExpr::Alias { expr, .. } => Self::is_constant_expr(expr), + _ => false, + } + } + + // Check if an expression is valid in GROUP BY context + // An expression is valid if it's: + // 1. A constant literal + // 2. An aggregate function + // 3. A grouping column (or expression involving only grouping columns) + fn is_valid_in_group_by(expr: &LogicalExpr, group_exprs: &[LogicalExpr]) -> bool { + match expr { + LogicalExpr::Literal(_) => true, // Constants are always valid + LogicalExpr::AggregateFunction { .. } => true, // Aggregates are valid + LogicalExpr::Column(col) => { + // Check if this column is in the GROUP BY + group_exprs.iter().any(|g| match g { + LogicalExpr::Column(gcol) => gcol.name == col.name, + _ => false, + }) + } + LogicalExpr::BinaryExpr { left, right, .. } => { + // Both sides must be valid + Self::is_valid_in_group_by(left, group_exprs) + && Self::is_valid_in_group_by(right, group_exprs) + } + LogicalExpr::UnaryExpr { expr, .. } => Self::is_valid_in_group_by(expr, group_exprs), + LogicalExpr::ScalarFunction { args, .. } => { + // All arguments must be valid + args.iter() + .all(|arg| Self::is_valid_in_group_by(arg, group_exprs)) + } + LogicalExpr::Alias { expr, .. } => Self::is_valid_in_group_by(expr, group_exprs), + _ => false, // Other expressions are not valid + } + } + + // Extract aggregates from an expression and replace them with column references, with deduplication + // Returns the modified expression and a list of NEW (aggregate_expr, column_name) pairs + fn extract_and_replace_aggregates_with_dedup( + expr: LogicalExpr, + aggregate_map: &mut std::collections::HashMap, + ) -> Result<(LogicalExpr, Vec<(LogicalExpr, String)>)> { + let mut new_aggregates = Vec::new(); + let mut counter = aggregate_map.len(); + let new_expr = Self::replace_aggregates_with_columns_dedup( + expr, + &mut new_aggregates, + aggregate_map, + &mut counter, + )?; + Ok((new_expr, new_aggregates)) + } + + // Recursively replace aggregate functions with column references, with deduplication + fn replace_aggregates_with_columns_dedup( + expr: LogicalExpr, + new_aggregates: &mut Vec<(LogicalExpr, String)>, + aggregate_map: &mut std::collections::HashMap, + counter: &mut usize, + ) -> Result { + match expr { + LogicalExpr::AggregateFunction { .. } => { + // Found an aggregate - check if we've seen it before + let agg_key = format!("{expr:?}"); + + let col_name = if let Some(existing_name) = aggregate_map.get(&agg_key) { + // Reuse existing aggregate + existing_name.clone() + } else { + // New aggregate + let col_name = format!("__agg_{}", *counter); + *counter += 1; + aggregate_map.insert(agg_key, col_name.clone()); + new_aggregates.push((expr, col_name.clone())); + col_name + }; + + Ok(LogicalExpr::Column(Column { + name: col_name, + table: None, + })) + } + LogicalExpr::BinaryExpr { left, op, right } => { + let new_left = Self::replace_aggregates_with_columns_dedup( + *left, + new_aggregates, + aggregate_map, + counter, + )?; + let new_right = Self::replace_aggregates_with_columns_dedup( + *right, + new_aggregates, + aggregate_map, + counter, + )?; + Ok(LogicalExpr::BinaryExpr { + left: Box::new(new_left), + op, + right: Box::new(new_right), + }) + } + LogicalExpr::UnaryExpr { op, expr } => { + let new_expr = Self::replace_aggregates_with_columns_dedup( + *expr, + new_aggregates, + aggregate_map, + counter, + )?; + Ok(LogicalExpr::UnaryExpr { + op, + expr: Box::new(new_expr), + }) + } + LogicalExpr::ScalarFunction { fun, args } => { + let mut new_args = Vec::new(); + for arg in args { + new_args.push(Self::replace_aggregates_with_columns_dedup( + arg, + new_aggregates, + aggregate_map, + counter, + )?); + } + Ok(LogicalExpr::ScalarFunction { + fun, + args: new_args, + }) + } + LogicalExpr::Case { + expr: case_expr, + when_then, + else_expr, + } => { + let new_case_expr = if let Some(e) = case_expr { + Some(Box::new(Self::replace_aggregates_with_columns_dedup( + *e, + new_aggregates, + aggregate_map, + counter, + )?)) + } else { + None + }; + + let mut new_when_then = Vec::new(); + for (when, then) in when_then { + let new_when = Self::replace_aggregates_with_columns_dedup( + when, + new_aggregates, + aggregate_map, + counter, + )?; + let new_then = Self::replace_aggregates_with_columns_dedup( + then, + new_aggregates, + aggregate_map, + counter, + )?; + new_when_then.push((new_when, new_then)); + } + + let new_else = if let Some(e) = else_expr { + Some(Box::new(Self::replace_aggregates_with_columns_dedup( + *e, + new_aggregates, + aggregate_map, + counter, + )?)) + } else { + None + }; + + Ok(LogicalExpr::Case { + expr: new_case_expr, + when_then: new_when_then, + else_expr: new_else, + }) + } + LogicalExpr::Alias { expr, alias } => { + let new_expr = Self::replace_aggregates_with_columns_dedup( + *expr, + new_aggregates, + aggregate_map, + counter, + )?; + Ok(LogicalExpr::Alias { + expr: Box::new(new_expr), + alias, + }) + } + // Other expressions - keep as is + _ => Ok(expr), + } + } + + // Get column name from expression + fn expr_to_column_name(expr: &ast::Expr) -> String { + match expr { + ast::Expr::Id(name) => Self::name_to_string(name), + ast::Expr::Qualified(_, col) => Self::name_to_string(col), + ast::Expr::FunctionCall { name, .. } => Self::name_to_string(name), + ast::Expr::FunctionCallStar { name, .. } => { + format!("{}(*)", Self::name_to_string(name)) + } + _ => "expr".to_string(), + } + } + + // Get table schema + fn get_table_schema(&self, table_name: &str) -> Result { + // Look up table in schema + let table = self + .schema + .get_table(table_name) + .ok_or_else(|| LimboError::ParseError(format!("Table '{table_name}' not found")))?; + + let mut columns = Vec::new(); + for col in table.columns() { + if let Some(ref name) = col.name { + columns.push((name.clone(), col.ty)); + } + } + + Ok(Arc::new(LogicalSchema::new(columns))) + } + + // Infer expression type + fn infer_expr_type(expr: &LogicalExpr, schema: &SchemaRef) -> Result { + match expr { + LogicalExpr::Column(col) => { + if let Some((_, typ)) = schema.find_column(&col.name) { + Ok(*typ) + } else { + Ok(Type::Text) + } + } + LogicalExpr::Literal(Value::Integer(_)) => Ok(Type::Integer), + LogicalExpr::Literal(Value::Float(_)) => Ok(Type::Real), + LogicalExpr::Literal(Value::Text(_)) => Ok(Type::Text), + LogicalExpr::Literal(Value::Null) => Ok(Type::Null), + LogicalExpr::Literal(Value::Blob(_)) => Ok(Type::Blob), + LogicalExpr::BinaryExpr { op, left, right } => { + match op { + ast::Operator::Add | ast::Operator::Subtract | ast::Operator::Multiply => { + // Infer types of operands to match SQLite/Numeric behavior + let left_type = Self::infer_expr_type(left, schema)?; + let right_type = Self::infer_expr_type(right, schema)?; + + // Integer op Integer = Integer (matching core/numeric/mod.rs behavior) + // Any operation with Real = Real + match (left_type, right_type) { + (Type::Integer, Type::Integer) => Ok(Type::Integer), + (Type::Integer, Type::Real) + | (Type::Real, Type::Integer) + | (Type::Real, Type::Real) => Ok(Type::Real), + (Type::Null, _) | (_, Type::Null) => Ok(Type::Null), + // For Text/Blob, SQLite coerces to numeric, defaulting to Real + _ => Ok(Type::Real), + } + } + ast::Operator::Divide => { + // Division always produces Real in SQLite + Ok(Type::Real) + } + ast::Operator::Modulus => { + // Modulus follows same rules as other arithmetic ops + let left_type = Self::infer_expr_type(left, schema)?; + let right_type = Self::infer_expr_type(right, schema)?; + match (left_type, right_type) { + (Type::Integer, Type::Integer) => Ok(Type::Integer), + _ => Ok(Type::Real), + } + } + ast::Operator::Equals + | ast::Operator::NotEquals + | ast::Operator::Less + | ast::Operator::LessEquals + | ast::Operator::Greater + | ast::Operator::GreaterEquals + | ast::Operator::And + | ast::Operator::Or + | ast::Operator::Is + | ast::Operator::IsNot => Ok(Type::Integer), + ast::Operator::Concat => Ok(Type::Text), + _ => Ok(Type::Text), // Default for other operators + } + } + LogicalExpr::UnaryExpr { op, expr } => match op { + ast::UnaryOperator::Not => Ok(Type::Integer), + ast::UnaryOperator::Negative | ast::UnaryOperator::Positive => { + Self::infer_expr_type(expr, schema) + } + ast::UnaryOperator::BitwiseNot => Ok(Type::Integer), + }, + LogicalExpr::AggregateFunction { fun, .. } => match fun { + AggFunc::Count | AggFunc::Count0 => Ok(Type::Integer), + AggFunc::Sum | AggFunc::Avg | AggFunc::Total => Ok(Type::Real), + AggFunc::Min | AggFunc::Max => Ok(Type::Text), + AggFunc::GroupConcat | AggFunc::StringAgg => Ok(Type::Text), + #[cfg(feature = "json")] + AggFunc::JsonbGroupArray + | AggFunc::JsonGroupArray + | AggFunc::JsonbGroupObject + | AggFunc::JsonGroupObject => Ok(Type::Text), + AggFunc::External(_) => Ok(Type::Text), // Default for external + }, + LogicalExpr::Alias { expr, .. } => Self::infer_expr_type(expr, schema), + LogicalExpr::IsNull { .. } => Ok(Type::Integer), + LogicalExpr::InList { .. } | LogicalExpr::InSubquery { .. } => Ok(Type::Integer), + LogicalExpr::Exists { .. } => Ok(Type::Integer), + LogicalExpr::Between { .. } => Ok(Type::Integer), + LogicalExpr::Like { .. } => Ok(Type::Integer), + _ => Ok(Type::Text), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::schema::{BTreeTable, Column as SchemaColumn, Schema, Type}; + use turso_parser::parser::Parser; + + fn create_test_schema() -> Schema { + let mut schema = Schema::new(false); + + // Create users table + let users_table = BTreeTable { + name: "users".to_string(), + root_page: 2, + primary_key_columns: vec![("id".to_string(), turso_parser::ast::SortOrder::Asc)], + columns: vec![ + SchemaColumn { + name: Some("id".to_string()), + ty: Type::Integer, + ty_str: "INTEGER".to_string(), + primary_key: true, + is_rowid_alias: true, + notnull: true, + default: None, + unique: false, + collation: None, + hidden: false, + }, + SchemaColumn { + name: Some("name".to_string()), + ty: Type::Text, + ty_str: "TEXT".to_string(), + primary_key: false, + is_rowid_alias: false, + notnull: false, + default: None, + unique: false, + collation: None, + hidden: false, + }, + SchemaColumn { + name: Some("age".to_string()), + ty: Type::Integer, + ty_str: "INTEGER".to_string(), + primary_key: false, + is_rowid_alias: false, + notnull: false, + default: None, + unique: false, + collation: None, + hidden: false, + }, + SchemaColumn { + name: Some("email".to_string()), + ty: Type::Text, + ty_str: "TEXT".to_string(), + primary_key: false, + is_rowid_alias: false, + notnull: false, + default: None, + unique: false, + collation: None, + hidden: false, + }, + ], + has_rowid: true, + is_strict: false, + unique_sets: None, + }; + schema.add_btree_table(Arc::new(users_table)); + + // Create orders table + let orders_table = BTreeTable { + name: "orders".to_string(), + root_page: 3, + primary_key_columns: vec![("id".to_string(), turso_parser::ast::SortOrder::Asc)], + columns: vec![ + SchemaColumn { + name: Some("id".to_string()), + ty: Type::Integer, + ty_str: "INTEGER".to_string(), + primary_key: true, + is_rowid_alias: true, + notnull: true, + default: None, + unique: false, + collation: None, + hidden: false, + }, + SchemaColumn { + name: Some("user_id".to_string()), + ty: Type::Integer, + ty_str: "INTEGER".to_string(), + primary_key: false, + is_rowid_alias: false, + notnull: false, + default: None, + unique: false, + collation: None, + hidden: false, + }, + SchemaColumn { + name: Some("product".to_string()), + ty: Type::Text, + ty_str: "TEXT".to_string(), + primary_key: false, + is_rowid_alias: false, + notnull: false, + default: None, + unique: false, + collation: None, + hidden: false, + }, + SchemaColumn { + name: Some("amount".to_string()), + ty: Type::Real, + ty_str: "REAL".to_string(), + primary_key: false, + is_rowid_alias: false, + notnull: false, + default: None, + unique: false, + collation: None, + hidden: false, + }, + ], + has_rowid: true, + is_strict: false, + unique_sets: None, + }; + schema.add_btree_table(Arc::new(orders_table)); + + schema + } + + fn parse_and_build(sql: &str, schema: &Schema) -> Result { + let mut parser = Parser::new(sql.as_bytes()); + let cmd = parser + .next() + .ok_or_else(|| LimboError::ParseError("Empty statement".to_string()))? + .map_err(|e| LimboError::ParseError(e.to_string()))?; + match cmd { + ast::Cmd::Stmt(stmt) => { + let mut builder = LogicalPlanBuilder::new(schema); + builder.build_statement(&stmt) + } + _ => Err(LimboError::ParseError( + "Only SQL statements are supported".to_string(), + )), + } + } + + #[test] + fn test_simple_select() { + let schema = create_test_schema(); + let sql = "SELECT id, name FROM users"; + let plan = parse_and_build(sql, &schema).unwrap(); + + match plan { + LogicalPlan::Projection(proj) => { + assert_eq!(proj.exprs.len(), 2); + assert!(matches!(proj.exprs[0], LogicalExpr::Column(_))); + assert!(matches!(proj.exprs[1], LogicalExpr::Column(_))); + + match &*proj.input { + LogicalPlan::TableScan(scan) => { + assert_eq!(scan.table_name, "users"); + } + _ => panic!("Expected TableScan"), + } + } + _ => panic!("Expected Projection"), + } + } + + #[test] + fn test_select_with_filter() { + let schema = create_test_schema(); + let sql = "SELECT name FROM users WHERE age > 18"; + let plan = parse_and_build(sql, &schema).unwrap(); + + match plan { + LogicalPlan::Projection(proj) => { + assert_eq!(proj.exprs.len(), 1); + + match &*proj.input { + LogicalPlan::Filter(filter) => { + assert!(matches!( + filter.predicate, + LogicalExpr::BinaryExpr { + op: ast::Operator::Greater, + .. + } + )); + + match &*filter.input { + LogicalPlan::TableScan(scan) => { + assert_eq!(scan.table_name, "users"); + } + _ => panic!("Expected TableScan"), + } + } + _ => panic!("Expected Filter"), + } + } + _ => panic!("Expected Projection"), + } + } + + #[test] + fn test_aggregate_with_group_by() { + let schema = create_test_schema(); + let sql = "SELECT user_id, SUM(amount) FROM orders GROUP BY user_id"; + let plan = parse_and_build(sql, &schema).unwrap(); + + match plan { + LogicalPlan::Aggregate(agg) => { + assert_eq!(agg.group_expr.len(), 1); + assert_eq!(agg.aggr_expr.len(), 1); + assert_eq!(agg.schema.column_count(), 2); + + assert!(matches!( + agg.aggr_expr[0], + LogicalExpr::AggregateFunction { + fun: AggFunc::Sum, + .. + } + )); + + match &*agg.input { + LogicalPlan::TableScan(scan) => { + assert_eq!(scan.table_name, "orders"); + } + _ => panic!("Expected TableScan"), + } + } + _ => panic!("Expected Aggregate (no projection)"), + } + } + + #[test] + fn test_aggregate_without_group_by() { + let schema = create_test_schema(); + let sql = "SELECT COUNT(*), MAX(age) FROM users"; + let plan = parse_and_build(sql, &schema).unwrap(); + + match plan { + LogicalPlan::Aggregate(agg) => { + assert_eq!(agg.group_expr.len(), 0); + assert_eq!(agg.aggr_expr.len(), 2); + assert_eq!(agg.schema.column_count(), 2); + + assert!(matches!( + agg.aggr_expr[0], + LogicalExpr::AggregateFunction { + fun: AggFunc::Count, + .. + } + )); + + assert!(matches!( + agg.aggr_expr[1], + LogicalExpr::AggregateFunction { + fun: AggFunc::Max, + .. + } + )); + } + _ => panic!("Expected Aggregate (no projection)"), + } + } + + #[test] + fn test_order_by() { + let schema = create_test_schema(); + let sql = "SELECT name FROM users ORDER BY age DESC, name ASC"; + let plan = parse_and_build(sql, &schema).unwrap(); + + match plan { + LogicalPlan::Sort(sort) => { + assert_eq!(sort.exprs.len(), 2); + assert!(!sort.exprs[0].asc); // DESC + assert!(sort.exprs[1].asc); // ASC + + match &*sort.input { + LogicalPlan::Projection(_) => {} + _ => panic!("Expected Projection"), + } + } + _ => panic!("Expected Sort"), + } + } + + #[test] + fn test_limit_offset() { + let schema = create_test_schema(); + let sql = "SELECT * FROM users LIMIT 10 OFFSET 5"; + let plan = parse_and_build(sql, &schema).unwrap(); + + match plan { + LogicalPlan::Limit(limit) => { + assert_eq!(limit.fetch, Some(10)); + assert_eq!(limit.skip, Some(5)); + } + _ => panic!("Expected Limit"), + } + } + + #[test] + fn test_order_by_with_limit() { + let schema = create_test_schema(); + let sql = "SELECT name FROM users ORDER BY age DESC LIMIT 5"; + let plan = parse_and_build(sql, &schema).unwrap(); + + // Should produce: Limit -> Sort -> Projection -> TableScan + match plan { + LogicalPlan::Limit(limit) => { + assert_eq!(limit.fetch, Some(5)); + assert_eq!(limit.skip, None); + + match &*limit.input { + LogicalPlan::Sort(sort) => { + assert_eq!(sort.exprs.len(), 1); + assert!(!sort.exprs[0].asc); // DESC + + match &*sort.input { + LogicalPlan::Projection(_) => {} + _ => panic!("Expected Projection under Sort"), + } + } + _ => panic!("Expected Sort under Limit"), + } + } + _ => panic!("Expected Limit at top level"), + } + } + + #[test] + fn test_distinct() { + let schema = create_test_schema(); + let sql = "SELECT DISTINCT name FROM users"; + let plan = parse_and_build(sql, &schema).unwrap(); + + match plan { + LogicalPlan::Distinct(distinct) => match &*distinct.input { + LogicalPlan::Projection(_) => {} + _ => panic!("Expected Projection"), + }, + _ => panic!("Expected Distinct"), + } + } + + #[test] + fn test_union() { + let schema = create_test_schema(); + let sql = "SELECT id FROM users UNION SELECT user_id FROM orders"; + let plan = parse_and_build(sql, &schema).unwrap(); + + match plan { + LogicalPlan::Union(union) => { + assert!(!union.all); + assert_eq!(union.inputs.len(), 2); + } + _ => panic!("Expected Union"), + } + } + + #[test] + fn test_union_all() { + let schema = create_test_schema(); + let sql = "SELECT id FROM users UNION ALL SELECT user_id FROM orders"; + let plan = parse_and_build(sql, &schema).unwrap(); + + match plan { + LogicalPlan::Union(union) => { + assert!(union.all); + assert_eq!(union.inputs.len(), 2); + } + _ => panic!("Expected Union"), + } + } + + #[test] + fn test_union_with_order_by() { + let schema = create_test_schema(); + let sql = "SELECT id, name FROM users UNION SELECT user_id, name FROM orders ORDER BY id"; + let plan = parse_and_build(sql, &schema).unwrap(); + + match plan { + LogicalPlan::Sort(sort) => { + assert_eq!(sort.exprs.len(), 1); + assert!(sort.exprs[0].asc); // Default ASC + + match &*sort.input { + LogicalPlan::Union(union) => { + assert!(!union.all); // UNION (not UNION ALL) + assert_eq!(union.inputs.len(), 2); + } + _ => panic!("Expected Union under Sort"), + } + } + _ => panic!("Expected Sort at top level"), + } + } + + #[test] + fn test_with_cte() { + let schema = create_test_schema(); + let sql = "WITH active_users AS (SELECT * FROM users WHERE age > 18) SELECT name FROM active_users"; + let plan = parse_and_build(sql, &schema).unwrap(); + + match plan { + LogicalPlan::WithCTE(with) => { + assert_eq!(with.ctes.len(), 1); + assert!(with.ctes.contains_key("active_users")); + + let cte = &with.ctes["active_users"]; + match &**cte { + LogicalPlan::Projection(proj) => match &*proj.input { + LogicalPlan::Filter(_) => {} + _ => panic!("Expected Filter in CTE"), + }, + _ => panic!("Expected Projection in CTE"), + } + + match &*with.body { + LogicalPlan::Projection(proj) => match &*proj.input { + LogicalPlan::CTERef(cte_ref) => { + assert_eq!(cte_ref.name, "active_users"); + } + _ => panic!("Expected CTERef"), + }, + _ => panic!("Expected Projection in body"), + } + } + _ => panic!("Expected WithCTE"), + } + } + + #[test] + fn test_case_expression() { + let schema = create_test_schema(); + let sql = "SELECT CASE WHEN age < 18 THEN 'minor' WHEN age < 65 THEN 'adult' ELSE 'senior' END FROM users"; + let plan = parse_and_build(sql, &schema).unwrap(); + + match plan { + LogicalPlan::Projection(proj) => { + assert_eq!(proj.exprs.len(), 1); + assert!(matches!(proj.exprs[0], LogicalExpr::Case { .. })); + } + _ => panic!("Expected Projection"), + } + } + + #[test] + fn test_in_list() { + let schema = create_test_schema(); + let sql = "SELECT * FROM users WHERE id IN (1, 2, 3)"; + let plan = parse_and_build(sql, &schema).unwrap(); + + match plan { + LogicalPlan::Projection(proj) => match &*proj.input { + LogicalPlan::Filter(filter) => match &filter.predicate { + LogicalExpr::InList { list, negated, .. } => { + assert!(!negated); + assert_eq!(list.len(), 3); + } + _ => panic!("Expected InList"), + }, + _ => panic!("Expected Filter"), + }, + _ => panic!("Expected Projection"), + } + } + + #[test] + fn test_in_subquery() { + let schema = create_test_schema(); + let sql = "SELECT * FROM users WHERE id IN (SELECT user_id FROM orders)"; + let plan = parse_and_build(sql, &schema).unwrap(); + + match plan { + LogicalPlan::Projection(proj) => match &*proj.input { + LogicalPlan::Filter(filter) => { + assert!(matches!(filter.predicate, LogicalExpr::InSubquery { .. })); + } + _ => panic!("Expected Filter"), + }, + _ => panic!("Expected Projection"), + } + } + + #[test] + fn test_exists_subquery() { + let schema = create_test_schema(); + let sql = "SELECT * FROM users WHERE EXISTS (SELECT 1 FROM orders WHERE orders.user_id = users.id)"; + let plan = parse_and_build(sql, &schema).unwrap(); + + match plan { + LogicalPlan::Projection(proj) => match &*proj.input { + LogicalPlan::Filter(filter) => { + assert!(matches!(filter.predicate, LogicalExpr::Exists { .. })); + } + _ => panic!("Expected Filter"), + }, + _ => panic!("Expected Projection"), + } + } + + #[test] + fn test_between() { + let schema = create_test_schema(); + let sql = "SELECT * FROM users WHERE age BETWEEN 18 AND 65"; + let plan = parse_and_build(sql, &schema).unwrap(); + + match plan { + LogicalPlan::Projection(proj) => match &*proj.input { + LogicalPlan::Filter(filter) => match &filter.predicate { + LogicalExpr::Between { negated, .. } => { + assert!(!negated); + } + _ => panic!("Expected Between"), + }, + _ => panic!("Expected Filter"), + }, + _ => panic!("Expected Projection"), + } + } + + #[test] + fn test_like() { + let schema = create_test_schema(); + let sql = "SELECT * FROM users WHERE name LIKE 'John%'"; + let plan = parse_and_build(sql, &schema).unwrap(); + + match plan { + LogicalPlan::Projection(proj) => match &*proj.input { + LogicalPlan::Filter(filter) => match &filter.predicate { + LogicalExpr::Like { + negated, escape, .. + } => { + assert!(!negated); + assert!(escape.is_none()); + } + _ => panic!("Expected Like"), + }, + _ => panic!("Expected Filter"), + }, + _ => panic!("Expected Projection"), + } + } + + #[test] + fn test_is_null() { + let schema = create_test_schema(); + let sql = "SELECT * FROM users WHERE email IS NULL"; + let plan = parse_and_build(sql, &schema).unwrap(); + + match plan { + LogicalPlan::Projection(proj) => match &*proj.input { + LogicalPlan::Filter(filter) => match &filter.predicate { + LogicalExpr::IsNull { negated, .. } => { + assert!(!negated); + } + _ => panic!("Expected IsNull, got: {:?}", filter.predicate), + }, + _ => panic!("Expected Filter"), + }, + _ => panic!("Expected Projection"), + } + } + + #[test] + fn test_is_not_null() { + let schema = create_test_schema(); + let sql = "SELECT * FROM users WHERE email IS NOT NULL"; + let plan = parse_and_build(sql, &schema).unwrap(); + + match plan { + LogicalPlan::Projection(proj) => match &*proj.input { + LogicalPlan::Filter(filter) => match &filter.predicate { + LogicalExpr::IsNull { negated, .. } => { + assert!(negated); + } + _ => panic!("Expected IsNull"), + }, + _ => panic!("Expected Filter"), + }, + _ => panic!("Expected Projection"), + } + } + + #[test] + fn test_values_clause() { + let schema = create_test_schema(); + let sql = "SELECT * FROM (VALUES (1, 'a'), (2, 'b'))"; + let plan = parse_and_build(sql, &schema).unwrap(); + + match plan { + LogicalPlan::Projection(proj) => match &*proj.input { + LogicalPlan::Values(values) => { + assert_eq!(values.rows.len(), 2); + assert_eq!(values.rows[0].len(), 2); + } + _ => panic!("Expected Values"), + }, + _ => panic!("Expected Projection"), + } + } + + #[test] + fn test_complex_expression_with_aggregation() { + // Test: SELECT sum(id + 2) * 2 FROM orders GROUP BY user_id + let schema = create_test_schema(); + + // Test the complex case: sum((id + 2)) * 2 with parentheses + let sql = "SELECT sum((id + 2)) * 2 FROM orders GROUP BY user_id"; + let plan = parse_and_build(sql, &schema).unwrap(); + + match plan { + LogicalPlan::Projection(proj) => { + assert_eq!(proj.exprs.len(), 1); + match &proj.exprs[0] { + LogicalExpr::BinaryExpr { left, op, right } => { + assert_eq!(*op, BinaryOperator::Multiply); + assert!(matches!(**left, LogicalExpr::Column(_))); + assert!(matches!(**right, LogicalExpr::Literal(_))); + } + _ => panic!("Expected BinaryExpr in projection"), + } + + match &*proj.input { + LogicalPlan::Aggregate(agg) => { + assert_eq!(agg.group_expr.len(), 1); + + assert_eq!(agg.aggr_expr.len(), 1); + match &agg.aggr_expr[0] { + LogicalExpr::AggregateFunction { fun, args, .. } => { + assert_eq!(*fun, AggregateFunction::Sum); + assert_eq!(args.len(), 1); + match &args[0] { + LogicalExpr::Column(col) => { + assert!(col.name.starts_with("__agg_arg_proj_")); + } + _ => panic!("Expected Column reference to projected expression in aggregate args, got {:?}", args[0]), + } + } + _ => panic!("Expected AggregateFunction"), + } + + match &*agg.input { + LogicalPlan::Projection(inner_proj) => { + assert!(inner_proj.exprs.len() >= 2); + let has_binary_add = inner_proj.exprs.iter().any(|e| { + matches!( + e, + LogicalExpr::BinaryExpr { + op: BinaryOperator::Add, + .. + } + ) + }); + assert!( + has_binary_add, + "Should have id + 2 expression in inner projection" + ); + } + _ => panic!("Expected Projection as input to Aggregate"), + } + } + _ => panic!("Expected Aggregate under Projection"), + } + } + _ => panic!("Expected Projection at top level"), + } + } + + #[test] + fn test_function_on_aggregate_result() { + let schema = create_test_schema(); + + let sql = "SELECT abs(sum(id)) FROM orders GROUP BY user_id"; + let plan = parse_and_build(sql, &schema).unwrap(); + match plan { + LogicalPlan::Projection(proj) => { + assert_eq!(proj.exprs.len(), 1); + match &proj.exprs[0] { + LogicalExpr::ScalarFunction { fun, args } => { + assert_eq!(fun, "abs"); + assert_eq!(args.len(), 1); + assert!(matches!(args[0], LogicalExpr::Column(_))); + } + _ => panic!("Expected ScalarFunction in projection"), + } + } + _ => panic!("Expected Projection"), + } + } + + #[test] + fn test_multiple_aggregates_with_arithmetic() { + let schema = create_test_schema(); + + let sql = "SELECT sum(id) * 2 + count(*) FROM orders GROUP BY user_id"; + let plan = parse_and_build(sql, &schema).unwrap(); + match plan { + LogicalPlan::Projection(proj) => { + assert_eq!(proj.exprs.len(), 1); + match &proj.exprs[0] { + LogicalExpr::BinaryExpr { op, .. } => { + assert_eq!(*op, BinaryOperator::Add); + } + _ => panic!("Expected BinaryExpr"), + } + + match &*proj.input { + LogicalPlan::Aggregate(agg) => { + assert_eq!(agg.aggr_expr.len(), 2); + } + _ => panic!("Expected Aggregate"), + } + } + _ => panic!("Expected Projection"), + } + } + + #[test] + fn test_projection_aggregation_projection() { + let schema = create_test_schema(); + + // This tests: projection -> aggregation -> projection + // The inner projection computes (id + 2), then we aggregate sum(), then apply abs() + let sql = "SELECT abs(sum(id + 2)) FROM orders GROUP BY user_id"; + let plan = parse_and_build(sql, &schema).unwrap(); + + // Should produce: Projection(abs) -> Aggregate(sum) -> Projection(id + 2) -> TableScan + match plan { + LogicalPlan::Projection(outer_proj) => { + assert_eq!(outer_proj.exprs.len(), 1); + + // Outer projection should apply abs() function + match &outer_proj.exprs[0] { + LogicalExpr::ScalarFunction { fun, args } => { + assert_eq!(fun, "abs"); + assert_eq!(args.len(), 1); + assert!(matches!(args[0], LogicalExpr::Column(_))); + } + _ => panic!("Expected abs() function in outer projection"), + } + + // Next should be the Aggregate + match &*outer_proj.input { + LogicalPlan::Aggregate(agg) => { + assert_eq!(agg.group_expr.len(), 1); + assert_eq!(agg.aggr_expr.len(), 1); + + // The aggregate should be summing a column reference + match &agg.aggr_expr[0] { + LogicalExpr::AggregateFunction { fun, args, .. } => { + assert_eq!(*fun, AggregateFunction::Sum); + assert_eq!(args.len(), 1); + + // Should reference the projected column + match &args[0] { + LogicalExpr::Column(col) => { + assert!(col.name.starts_with("__agg_arg_proj_")); + } + _ => panic!("Expected column reference in aggregate"), + } + } + _ => panic!("Expected AggregateFunction"), + } + + // Input to aggregate should be a projection computing id + 2 + match &*agg.input { + LogicalPlan::Projection(inner_proj) => { + // Should have at least the group column and the computed expression + assert!(inner_proj.exprs.len() >= 2); + + // Check for the id + 2 expression + let has_add_expr = inner_proj.exprs.iter().any(|e| { + matches!( + e, + LogicalExpr::BinaryExpr { + op: BinaryOperator::Add, + .. + } + ) + }); + assert!( + has_add_expr, + "Should have id + 2 expression in inner projection" + ); + } + _ => panic!("Expected inner Projection under Aggregate"), + } + } + _ => panic!("Expected Aggregate under outer Projection"), + } + } + _ => panic!("Expected Projection at top level"), + } + } + + #[test] + fn test_group_by_validation_allow_grouped_column() { + let schema = create_test_schema(); + + // Test that grouped columns are allowed + let sql = "SELECT user_id, COUNT(*) FROM orders GROUP BY user_id"; + let result = parse_and_build(sql, &schema); + + assert!(result.is_ok(), "Should allow grouped column in SELECT"); + } + + #[test] + fn test_group_by_validation_allow_constants() { + let schema = create_test_schema(); + + // Test that simple constants are allowed even when not grouped + let sql = "SELECT user_id, 42, COUNT(*) FROM orders GROUP BY user_id"; + let result = parse_and_build(sql, &schema); + + assert!( + result.is_ok(), + "Should allow simple constants in SELECT with GROUP BY" + ); + + let sql_complex = "SELECT user_id, (100 + 50) * 2, COUNT(*) FROM orders GROUP BY user_id"; + let result_complex = parse_and_build(sql_complex, &schema); + + assert!( + result_complex.is_ok(), + "Should allow complex constant expressions in SELECT with GROUP BY" + ); + } + + #[test] + fn test_parenthesized_aggregate_expressions() { + let schema = create_test_schema(); + + let sql = "SELECT 25, (MAX(id) / 3), 39 FROM orders"; + let result = parse_and_build(sql, &schema); + + assert!( + result.is_ok(), + "Should handle parenthesized aggregate expressions" + ); + + let plan = result.unwrap(); + + match plan { + LogicalPlan::Projection(proj) => { + assert_eq!(proj.exprs.len(), 3); + + assert!(matches!( + proj.exprs[0], + LogicalExpr::Literal(Value::Integer(25)) + )); + + match &proj.exprs[1] { + LogicalExpr::BinaryExpr { left, op, right } => { + assert_eq!(*op, BinaryOperator::Divide); + assert!(matches!(&**left, LogicalExpr::Column(_))); + assert!(matches!(&**right, LogicalExpr::Literal(Value::Integer(3)))); + } + _ => panic!("Expected BinaryExpr for (MAX(id) / 3)"), + } + + assert!(matches!( + proj.exprs[2], + LogicalExpr::Literal(Value::Integer(39)) + )); + + match &*proj.input { + LogicalPlan::Aggregate(agg) => { + assert_eq!(agg.aggr_expr.len(), 1); + assert!(matches!( + agg.aggr_expr[0], + LogicalExpr::AggregateFunction { + fun: AggFunc::Max, + .. + } + )); + } + _ => panic!("Expected Aggregate node under Projection"), + } + } + _ => panic!("Expected Projection at top level"), + } + } + + #[test] + fn test_duplicate_aggregate_reuse() { + let schema = create_test_schema(); + + let sql = "SELECT (COUNT(*) - 225), 30, COUNT(*) FROM orders"; + let result = parse_and_build(sql, &schema); + + assert!(result.is_ok(), "Should handle duplicate aggregates"); + + let plan = result.unwrap(); + + match plan { + LogicalPlan::Projection(proj) => { + assert_eq!(proj.exprs.len(), 3); + + match &proj.exprs[0] { + LogicalExpr::BinaryExpr { left, op, right } => { + assert_eq!(*op, BinaryOperator::Subtract); + match &**left { + LogicalExpr::Column(col) => { + assert!(col.name.starts_with("__agg_") || col.name == "COUNT(*)"); + } + _ => panic!("Expected Column reference for COUNT(*)"), + } + assert!(matches!( + &**right, + LogicalExpr::Literal(Value::Integer(225)) + )); + } + _ => panic!("Expected BinaryExpr for (COUNT(*) - 225)"), + } + + assert!(matches!( + proj.exprs[1], + LogicalExpr::Literal(Value::Integer(30)) + )); + + match &proj.exprs[2] { + LogicalExpr::Column(col) => { + assert!(col.name.starts_with("__agg_") || col.name == "COUNT(*)"); + } + _ => panic!("Expected Column reference for COUNT(*)"), + } + + match &*proj.input { + LogicalPlan::Aggregate(agg) => { + assert_eq!( + agg.aggr_expr.len(), + 1, + "Should have only one COUNT(*) aggregate" + ); + assert!(matches!( + agg.aggr_expr[0], + LogicalExpr::AggregateFunction { + fun: AggFunc::Count, + .. + } + )); + } + _ => panic!("Expected Aggregate node under Projection"), + } + } + _ => panic!("Expected Projection at top level"), + } + } + + #[test] + fn test_aggregate_without_group_by_allow_constants() { + let schema = create_test_schema(); + + // Test that constants are allowed with aggregates even without GROUP BY + let sql = "SELECT 42, COUNT(*), MAX(amount) FROM orders"; + let result = parse_and_build(sql, &schema); + + assert!( + result.is_ok(), + "Should allow simple constants with aggregates without GROUP BY" + ); + + // Test complex constant expressions + let sql_complex = "SELECT (9 / 6) % 5, COUNT(*), MAX(amount) FROM orders"; + let result_complex = parse_and_build(sql_complex, &schema); + + assert!( + result_complex.is_ok(), + "Should allow complex constant expressions with aggregates without GROUP BY" + ); + } + + #[test] + fn test_aggregate_without_group_by_creates_aggregate_node() { + let schema = create_test_schema(); + + // Test that aggregate without GROUP BY creates proper Aggregate node + let sql = "SELECT MAX(amount) FROM orders"; + let plan = parse_and_build(sql, &schema).unwrap(); + + // Should be: Aggregate -> TableScan (no projection needed for simple aggregate) + match plan { + LogicalPlan::Aggregate(agg) => { + assert_eq!(agg.group_expr.len(), 0, "Should have no group expressions"); + assert_eq!( + agg.aggr_expr.len(), + 1, + "Should have one aggregate expression" + ); + assert_eq!( + agg.schema.column_count(), + 1, + "Schema should have one column" + ); + } + _ => panic!("Expected Aggregate at top level (no projection)"), + } + } + + #[test] + fn test_scalar_vs_aggregate_function_classification() { + let schema = create_test_schema(); + + // Test MIN/MAX with 1 argument - should be aggregate + let sql = "SELECT MIN(amount) FROM orders"; + let plan = parse_and_build(sql, &schema).unwrap(); + match plan { + LogicalPlan::Aggregate(agg) => { + assert_eq!(agg.aggr_expr.len(), 1, "MIN(x) should be an aggregate"); + match &agg.aggr_expr[0] { + LogicalExpr::AggregateFunction { fun, args, .. } => { + assert!(matches!(fun, AggFunc::Min)); + assert_eq!(args.len(), 1); + } + _ => panic!("Expected AggregateFunction"), + } + } + _ => panic!("Expected Aggregate node for MIN(x)"), + } + + // Test MIN/MAX with 2 arguments - should be scalar in projection + let sql = "SELECT MIN(amount, user_id) FROM orders"; + let plan = parse_and_build(sql, &schema).unwrap(); + match plan { + LogicalPlan::Projection(proj) => { + assert_eq!(proj.exprs.len(), 1, "Should have one projection expression"); + match &proj.exprs[0] { + LogicalExpr::ScalarFunction { fun, args } => { + assert_eq!( + fun.to_lowercase(), + "min", + "MIN(x,y) should be a scalar function" + ); + assert_eq!(args.len(), 2); + } + _ => panic!("Expected ScalarFunction for MIN(x,y)"), + } + } + _ => panic!("Expected Projection node for scalar MIN(x,y)"), + } + + // Test MAX with 3 arguments - should be scalar + let sql = "SELECT MAX(amount, user_id, id) FROM orders"; + let plan = parse_and_build(sql, &schema).unwrap(); + match plan { + LogicalPlan::Projection(proj) => { + assert_eq!(proj.exprs.len(), 1); + match &proj.exprs[0] { + LogicalExpr::ScalarFunction { fun, args } => { + assert_eq!( + fun.to_lowercase(), + "max", + "MAX(x,y,z) should be a scalar function" + ); + assert_eq!(args.len(), 3); + } + _ => panic!("Expected ScalarFunction for MAX(x,y,z)"), + } + } + _ => panic!("Expected Projection node for scalar MAX(x,y,z)"), + } + + // Test that MIN with 0 args is treated as scalar (will fail later in execution) + let sql = "SELECT MIN() FROM orders"; + let plan = parse_and_build(sql, &schema).unwrap(); + match plan { + LogicalPlan::Projection(proj) => match &proj.exprs[0] { + LogicalExpr::ScalarFunction { fun, args } => { + assert_eq!(fun.to_lowercase(), "min"); + assert_eq!(args.len(), 0, "MIN() should be scalar with 0 args"); + } + _ => panic!("Expected ScalarFunction for MIN()"), + }, + _ => panic!("Expected Projection for MIN()"), + } + + // Test other functions that are always aggregate (COUNT, SUM, AVG) + let sql = "SELECT COUNT(*), SUM(amount), AVG(amount) FROM orders"; + let plan = parse_and_build(sql, &schema).unwrap(); + match plan { + LogicalPlan::Aggregate(agg) => { + assert_eq!(agg.aggr_expr.len(), 3, "Should have 3 aggregate functions"); + for expr in &agg.aggr_expr { + assert!(matches!(expr, LogicalExpr::AggregateFunction { .. })); + } + } + _ => panic!("Expected Aggregate node"), + } + + // Test scalar functions that are never aggregates (ABS, ROUND, etc.) + let sql = "SELECT ABS(amount), ROUND(amount), LENGTH(product) FROM orders"; + let plan = parse_and_build(sql, &schema).unwrap(); + match plan { + LogicalPlan::Projection(proj) => { + assert_eq!(proj.exprs.len(), 3, "Should have 3 scalar functions"); + for expr in &proj.exprs { + match expr { + LogicalExpr::ScalarFunction { .. } => {} + _ => panic!("Expected all ScalarFunctions"), + } + } + } + _ => panic!("Expected Projection node for scalar functions"), + } + } + + #[test] + fn test_mixed_aggregate_and_group_columns() { + let schema = create_test_schema(); + + // When selecting both aggregate and grouping columns + let sql = "SELECT user_id, sum(id) FROM orders GROUP BY user_id"; + let plan = parse_and_build(sql, &schema).unwrap(); + + // No projection needed - aggregate outputs exactly what we select + match plan { + LogicalPlan::Aggregate(agg) => { + assert_eq!(agg.group_expr.len(), 1); + assert_eq!(agg.aggr_expr.len(), 1); + assert_eq!(agg.schema.column_count(), 2); + } + _ => panic!("Expected Aggregate (no projection)"), + } + } + + #[test] + fn test_scalar_function_wrapping_aggregate_no_group_by() { + // Test: SELECT HEX(SUM(age + 2)) FROM users + // Expected structure: + // Projection { exprs: [ScalarFunction(HEX, [Column])] } + // -> Aggregate { aggr_expr: [Sum(BinaryExpr(age + 2))], group_expr: [] } + // -> Projection { exprs: [BinaryExpr(age + 2)] } + // -> TableScan("users") + + let schema = create_test_schema(); + let sql = "SELECT HEX(SUM(age + 2)) FROM users"; + let mut parser = Parser::new(sql.as_bytes()); + let stmt = parser.next().unwrap().unwrap(); + + let plan = match stmt { + ast::Cmd::Stmt(stmt) => { + let mut builder = LogicalPlanBuilder::new(&schema); + builder.build_statement(&stmt).unwrap() + } + _ => panic!("Expected SQL statement"), + }; + + match &plan { + LogicalPlan::Projection(proj) => { + assert_eq!(proj.exprs.len(), 1, "Should have one expression"); + + match &proj.exprs[0] { + LogicalExpr::ScalarFunction { fun, args } => { + assert_eq!(fun, "HEX", "Outer function should be HEX"); + assert_eq!(args.len(), 1, "HEX should have one argument"); + + match &args[0] { + LogicalExpr::Column(_) => {} + LogicalExpr::AggregateFunction { .. } => { + panic!("Aggregate function should not be embedded in projection! It should be in a separate Aggregate operator"); + } + _ => panic!( + "Expected column reference as argument to HEX, got: {:?}", + args[0] + ), + } + } + _ => panic!("Expected ScalarFunction (HEX), got: {:?}", proj.exprs[0]), + } + + match &*proj.input { + LogicalPlan::Aggregate(agg) => { + assert_eq!(agg.group_expr.len(), 0, "Should have no GROUP BY"); + assert_eq!( + agg.aggr_expr.len(), + 1, + "Should have one aggregate expression" + ); + + match &agg.aggr_expr[0] { + LogicalExpr::AggregateFunction { + fun, + args, + distinct, + } => { + assert_eq!(*fun, crate::function::AggFunc::Sum, "Should be SUM"); + assert!(!distinct, "Should not be DISTINCT"); + assert_eq!(args.len(), 1, "SUM should have one argument"); + + match &args[0] { + LogicalExpr::Column(col) => { + // When aggregate arguments are complex, they get pre-projected + assert!(col.name.starts_with("__agg_arg_proj_"), + "Should reference pre-projected column, got: {}", col.name); + } + LogicalExpr::BinaryExpr { left, op, right } => { + // Simple case without pre-projection (shouldn't happen with current implementation) + assert_eq!(*op, ast::Operator::Add, "Should be addition"); + + match (&**left, &**right) { + (LogicalExpr::Column(col), LogicalExpr::Literal(val)) => { + assert_eq!(col.name, "age", "Should reference age column"); + assert_eq!(*val, Value::Integer(2), "Should add 2"); + } + _ => panic!("Expected age + 2"), + } + } + _ => panic!("Expected Column reference or BinaryExpr for aggregate argument, got: {:?}", args[0]), + } + } + _ => panic!("Expected AggregateFunction"), + } + + match &*agg.input { + LogicalPlan::TableScan(scan) => { + assert_eq!(scan.table_name, "users"); + } + LogicalPlan::Projection(proj) => match &*proj.input { + LogicalPlan::TableScan(scan) => { + assert_eq!(scan.table_name, "users"); + } + _ => panic!("Expected TableScan under projection"), + }, + _ => panic!("Expected TableScan or Projection under Aggregate"), + } + } + _ => panic!( + "Expected Aggregate operator under Projection, got: {:?}", + proj.input + ), + } + } + _ => panic!("Expected Projection as top-level operator, got: {plan:?}"), + } + } +} diff --git a/core/translate/mod.rs b/core/translate/mod.rs index 7d29a8173..2de196a8f 100644 --- a/core/translate/mod.rs +++ b/core/translate/mod.rs @@ -21,6 +21,8 @@ pub(crate) mod group_by; pub(crate) mod index; pub(crate) mod insert; pub(crate) mod integrity_check; +#[cfg(test)] +pub(crate) mod logical; pub(crate) mod main_loop; pub(crate) mod optimizer; pub(crate) mod order_by;