//! Logical plan representation for SQL queries
//!
//! This module provides a platform-independent intermediate representation
//! for SQL queries. The logical plan is a DAG (Directed Acyclic Graph) that
//! supports CTEs and can be used for query optimization before being compiled
//! to an execution plan (e.g., DBSP circuits).
//!
//! The main entry point is `LogicalPlanBuilder` which constructs logical plans
//! from SQL AST nodes.
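//!
//! A minimal usage sketch (not compiled as a doctest; assumes `stmt` is an
//! already-parsed `ast::Stmt` and `schema` is the catalog `Schema`):
//!
//! ```ignore
//! let mut builder = LogicalPlanBuilder::new(&schema);
//! let plan = builder.build_statement(&stmt)?;
//! println!("plan produces {} columns", plan.schema().column_count());
//! ```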
use crate::function::AggFunc;
use crate::schema::{Schema, Type};
use crate::types::Value;
use crate::{LimboError, Result};
use std::collections::HashMap;
use std::fmt::{self, Display, Formatter};
use std::sync::Arc;
use turso_macros::match_ignore_ascii_case;
use turso_parser::ast;
/// Result type for preprocessing aggregate expressions
type PreprocessAggregateResult = (
bool, // needs_pre_projection
Vec<LogicalExpr>, // pre_projection_exprs
Vec<ColumnInfo>, // pre_projection_schema
Vec<LogicalExpr>, // modified_aggr_exprs
Vec<LogicalExpr>, // modified_group_exprs
);
/// Result type for parsing join conditions
type JoinConditionsResult = (Vec<(LogicalExpr, LogicalExpr)>, Option<LogicalExpr>);
/// Information about a column in a logical schema
#[derive(Debug, Clone, PartialEq)]
pub struct ColumnInfo {
pub name: String,
pub ty: Type,
pub database: Option<String>,
pub table: Option<String>,
pub table_alias: Option<String>,
}
/// Schema information for logical plan nodes
#[derive(Debug, Clone, PartialEq)]
pub struct LogicalSchema {
pub columns: Vec<ColumnInfo>,
}
/// A reference to a schema that can be shared between nodes
pub type SchemaRef = Arc<LogicalSchema>;
impl LogicalSchema {
pub fn new(columns: Vec<ColumnInfo>) -> Self {
Self { columns }
}
pub fn empty() -> Self {
Self {
columns: Vec::new(),
}
}
pub fn column_count(&self) -> usize {
self.columns.len()
}
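/// Resolve a column by name, optionally qualified by a table alias, a table
/// name, or a `database.table` prefix. Returns the column index together with
/// its `ColumnInfo`.
///
/// Illustrative sketch (not a doctest; the schema contents are hypothetical):
/// ```ignore
/// // A schema holding a single column `users.id`, scanned under the alias `u`:
/// assert!(schema.find_column("id", Some("u")).is_some());     // matches via alias
/// assert!(schema.find_column("id", Some("users")).is_some()); // matches via table name
/// assert!(schema.find_column("id", None).is_some());          // unqualified lookup
/// ```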
pub fn find_column(&self, name: &str, table: Option<&str>) -> Option<(usize, &ColumnInfo)> {
if let Some(table_ref) = table {
// Check if it's a database.table format
if table_ref.contains('.') {
let parts: Vec<&str> = table_ref.splitn(2, '.').collect();
if parts.len() == 2 {
let db = parts[0];
let tbl = parts[1];
return self
.columns
.iter()
.position(|c| {
c.name == name
&& c.database.as_deref() == Some(db)
&& c.table.as_deref() == Some(tbl)
})
.map(|idx| (idx, &self.columns[idx]));
}
}
// Try to match against table alias first, then table name
self.columns
.iter()
.position(|c| {
c.name == name
&& (c.table_alias.as_deref() == Some(table_ref)
|| c.table.as_deref() == Some(table_ref))
})
.map(|idx| (idx, &self.columns[idx]))
} else {
// Unqualified lookup - just match by name
self.columns
.iter()
.position(|c| c.name == name)
.map(|idx| (idx, &self.columns[idx]))
}
}
}
/// Logical representation of a SQL query plan
#[derive(Debug, Clone, PartialEq)]
pub enum LogicalPlan {
/// Projection - SELECT expressions
Projection(Projection),
/// Filter - WHERE/HAVING clause
Filter(Filter),
/// Aggregate - GROUP BY with aggregate functions
Aggregate(Aggregate),
/// Join - combining two relations
Join(Join),
/// Sort - ORDER BY clause
Sort(Sort),
/// Limit - LIMIT/OFFSET clause
Limit(Limit),
/// Table scan - reading from a base table
TableScan(TableScan),
/// Union - UNION/UNION ALL/INTERSECT/EXCEPT
Union(Union),
/// Distinct - remove duplicates
Distinct(Distinct),
/// Empty relation - no rows
EmptyRelation(EmptyRelation),
/// Values - literal rows (VALUES clause)
Values(Values),
/// CTE support - WITH clause
WithCTE(WithCTE),
/// Reference to a CTE
CTERef(CTERef),
}
impl LogicalPlan {
/// Get the schema of this plan node
pub fn schema(&self) -> &SchemaRef {
match self {
LogicalPlan::Projection(p) => &p.schema,
LogicalPlan::Filter(f) => f.input.schema(),
LogicalPlan::Aggregate(a) => &a.schema,
LogicalPlan::Join(j) => &j.schema,
LogicalPlan::Sort(s) => s.input.schema(),
LogicalPlan::Limit(l) => l.input.schema(),
LogicalPlan::TableScan(t) => &t.schema,
LogicalPlan::Union(u) => &u.schema,
LogicalPlan::Distinct(d) => d.input.schema(),
LogicalPlan::EmptyRelation(e) => &e.schema,
LogicalPlan::Values(v) => &v.schema,
LogicalPlan::WithCTE(w) => w.body.schema(),
LogicalPlan::CTERef(c) => &c.schema,
}
}
}
/// Projection operator - SELECT expressions
#[derive(Debug, Clone, PartialEq)]
pub struct Projection {
pub input: Arc<LogicalPlan>,
pub exprs: Vec<LogicalExpr>,
pub schema: SchemaRef,
}
/// Filter operator - WHERE/HAVING predicates
#[derive(Debug, Clone, PartialEq)]
pub struct Filter {
pub input: Arc<LogicalPlan>,
pub predicate: LogicalExpr,
}
/// Aggregate operator - GROUP BY with aggregations
#[derive(Debug, Clone, PartialEq)]
pub struct Aggregate {
pub input: Arc<LogicalPlan>,
pub group_expr: Vec<LogicalExpr>,
pub aggr_expr: Vec<LogicalExpr>,
pub schema: SchemaRef,
}
/// Types of joins
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum JoinType {
Inner,
Left,
Right,
Full,
Cross,
}
/// Join operator - combines two relations
#[derive(Debug, Clone, PartialEq)]
pub struct Join {
pub left: Arc<LogicalPlan>,
pub right: Arc<LogicalPlan>,
pub join_type: JoinType,
pub on: Vec<(LogicalExpr, LogicalExpr)>, // Equijoin conditions (left_expr, right_expr)
pub filter: Option<LogicalExpr>, // Additional filter conditions
pub schema: SchemaRef,
}
/// Sort operator - ORDER BY
#[derive(Debug, Clone, PartialEq)]
pub struct Sort {
pub input: Arc<LogicalPlan>,
pub exprs: Vec<SortExpr>,
}
/// Sort expression with direction
#[derive(Debug, Clone, PartialEq)]
pub struct SortExpr {
pub expr: LogicalExpr,
pub asc: bool,
pub nulls_first: bool,
}
/// Limit operator - LIMIT/OFFSET
#[derive(Debug, Clone, PartialEq)]
pub struct Limit {
pub input: Arc<LogicalPlan>,
pub skip: Option<usize>,
pub fetch: Option<usize>,
}
/// Table scan operator
#[derive(Debug, Clone, PartialEq)]
pub struct TableScan {
pub table_name: String,
pub alias: Option<String>,
pub schema: SchemaRef,
pub projection: Option<Vec<usize>>, // Column indices to project
}
/// Union operator
#[derive(Debug, Clone, PartialEq)]
pub struct Union {
pub inputs: Vec<Arc<LogicalPlan>>,
pub all: bool, // true for UNION ALL, false for UNION
pub schema: SchemaRef,
}
/// Distinct operator
#[derive(Debug, Clone, PartialEq)]
pub struct Distinct {
pub input: Arc<LogicalPlan>,
}
/// Empty relation - produces no rows
#[derive(Debug, Clone, PartialEq)]
pub struct EmptyRelation {
pub produce_one_row: bool,
pub schema: SchemaRef,
}
/// Values operator - literal rows
#[derive(Debug, Clone, PartialEq)]
pub struct Values {
pub rows: Vec<Vec<LogicalExpr>>,
pub schema: SchemaRef,
}
/// WITH clause - CTEs
#[derive(Debug, Clone, PartialEq)]
pub struct WithCTE {
pub ctes: HashMap<String, Arc<LogicalPlan>>,
pub body: Arc<LogicalPlan>,
}
/// Reference to a CTE
#[derive(Debug, Clone, PartialEq)]
pub struct CTERef {
pub name: String,
pub schema: SchemaRef,
}
/// Logical expression representation
#[derive(Debug, Clone, PartialEq)]
pub enum LogicalExpr {
/// Column reference
Column(Column),
/// Literal value
Literal(Value),
/// Binary expression
BinaryExpr {
left: Box<LogicalExpr>,
op: BinaryOperator,
right: Box<LogicalExpr>,
},
/// Unary expression
UnaryExpr {
op: UnaryOperator,
expr: Box<LogicalExpr>,
},
/// Aggregate function
AggregateFunction {
fun: AggregateFunction,
args: Vec<LogicalExpr>,
distinct: bool,
},
/// Scalar function call
ScalarFunction { fun: String, args: Vec<LogicalExpr> },
/// CASE expression
Case {
expr: Option<Box<LogicalExpr>>,
when_then: Vec<(LogicalExpr, LogicalExpr)>,
else_expr: Option<Box<LogicalExpr>>,
},
/// IN list
InList {
expr: Box<LogicalExpr>,
list: Vec<LogicalExpr>,
negated: bool,
},
/// IN subquery
InSubquery {
expr: Box<LogicalExpr>,
subquery: Arc<LogicalPlan>,
negated: bool,
},
/// EXISTS subquery
Exists {
subquery: Arc<LogicalPlan>,
negated: bool,
},
/// Scalar subquery
ScalarSubquery(Arc<LogicalPlan>),
/// Alias for an expression
Alias {
expr: Box<LogicalExpr>,
alias: String,
},
/// IS NULL / IS NOT NULL
IsNull {
expr: Box<LogicalExpr>,
negated: bool,
},
/// BETWEEN
Between {
expr: Box<LogicalExpr>,
low: Box<LogicalExpr>,
high: Box<LogicalExpr>,
negated: bool,
},
/// LIKE pattern matching
Like {
expr: Box<LogicalExpr>,
pattern: Box<LogicalExpr>,
escape: Option<char>,
negated: bool,
},
/// CAST expression
Cast {
expr: Box<LogicalExpr>,
type_name: Option<ast::Type>,
},
}
/// Column reference
#[derive(Debug, Clone, PartialEq)]
pub struct Column {
pub name: String,
pub table: Option<String>,
}
impl Column {
pub fn new(name: impl Into<String>) -> Self {
Self {
name: name.into(),
table: None,
}
}
pub fn with_table(name: impl Into<String>, table: impl Into<String>) -> Self {
Self {
name: name.into(),
table: Some(table.into()),
}
}
}
impl Display for Column {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
match &self.table {
Some(t) => write!(f, "{}.{}", t, self.name),
None => write!(f, "{}", self.name),
}
}
}
/// Strip alias wrapper from an expression, returning the underlying expression.
/// This is useful when comparing expressions where one might be aliased and the other not,
/// such as when matching SELECT expressions with GROUP BY expressions.
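///
/// Illustrative sketch (not a doctest; the expressions are hypothetical):
/// ```ignore
/// // `SUM(x) AS total` and `SUM(x)` compare equal once the alias is stripped.
/// assert_eq!(strip_alias(&aliased_sum), strip_alias(&plain_sum));
/// ```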
pub fn strip_alias(expr: &LogicalExpr) -> &LogicalExpr {
match expr {
LogicalExpr::Alias { expr, .. } => expr,
_ => expr,
}
}
/// Type alias for binary operators
pub type BinaryOperator = ast::Operator;
/// Type alias for unary operators
pub type UnaryOperator = ast::UnaryOperator;
/// Type alias for aggregate functions
pub type AggregateFunction = AggFunc;
/// Compiler from AST to LogicalPlan
pub struct LogicalPlanBuilder<'a> {
schema: &'a Schema,
ctes: HashMap<String, Arc<LogicalPlan>>,
}
impl<'a> LogicalPlanBuilder<'a> {
pub fn new(schema: &'a Schema) -> Self {
Self {
schema,
ctes: HashMap::new(),
}
}
/// Main entry point: compile a statement to a logical plan
pub fn build_statement(&mut self, stmt: &ast::Stmt) -> Result<LogicalPlan> {
match stmt {
ast::Stmt::Select(select) => self.build_select(select),
_ => Err(LimboError::ParseError(
"Only SELECT statements are currently supported in logical plans".to_string(),
)),
}
}
// Convert Name to String
fn name_to_string(name: &ast::Name) -> String {
name.as_str().to_string()
}
// Build a logical plan from a SELECT statement
fn build_select(&mut self, select: &ast::Select) -> Result<LogicalPlan> {
// Handle WITH clause if present
if let Some(with) = &select.with {
return self.build_with_cte(with, select);
}
// Build the main query body
let order_by = &select.order_by;
let limit = &select.limit;
self.build_select_body(&select.body, order_by, limit)
}
// Build WITH CTE
fn build_with_cte(&mut self, with: &ast::With, select: &ast::Select) -> Result<LogicalPlan> {
let mut cte_plans = HashMap::new();
// Build each CTE
for cte in &with.ctes {
let cte_plan = self.build_select(&cte.select)?;
let cte_name = Self::name_to_string(&cte.tbl_name);
cte_plans.insert(cte_name.clone(), Arc::new(cte_plan));
self.ctes
.insert(cte_name.clone(), cte_plans[&cte_name].clone());
}
// Build the main body with CTEs available
let order_by = &select.order_by;
let limit = &select.limit;
let body = self.build_select_body(&select.body, order_by, limit)?;
// Clear CTEs from builder context
for cte in &with.ctes {
self.ctes.remove(&Self::name_to_string(&cte.tbl_name));
}
Ok(LogicalPlan::WithCTE(WithCTE {
ctes: cte_plans,
body: Arc::new(body),
}))
}
// Build SELECT body
fn build_select_body(
&mut self,
body: &ast::SelectBody,
order_by: &[ast::SortedColumn],
limit: &Option<ast::Limit>,
) -> Result<LogicalPlan> {
let mut plan = self.build_one_select(&body.select)?;
// Handle compound operators (UNION, INTERSECT, EXCEPT)
if !body.compounds.is_empty() {
for compound in &body.compounds {
let right = self.build_one_select(&compound.select)?;
plan = Self::build_compound(plan, right, &compound.operator)?;
}
}
// Apply ORDER BY
if !order_by.is_empty() {
plan = self.build_sort(plan, order_by)?;
}
// Apply LIMIT
if let Some(limit) = limit {
plan = Self::build_limit(plan, limit)?;
}
Ok(plan)
}
// Build a single SELECT (without compounds)
fn build_one_select(&mut self, select: &ast::OneSelect) -> Result<LogicalPlan> {
match select {
ast::OneSelect::Select {
distinctness,
columns,
from,
where_clause,
group_by,
window_clause: _,
} => {
// Start with FROM clause
let mut plan = if let Some(from) = from {
self.build_from(from)?
} else {
// No FROM clause - single row
LogicalPlan::EmptyRelation(EmptyRelation {
produce_one_row: true,
schema: Arc::new(LogicalSchema::empty()),
})
};
// Apply WHERE
if let Some(where_expr) = where_clause {
let predicate = self.build_expr(where_expr, plan.schema())?;
plan = LogicalPlan::Filter(Filter {
input: Arc::new(plan),
predicate,
});
}
// Apply GROUP BY and aggregations
if let Some(group_by) = group_by {
plan = self.build_aggregate(plan, group_by, columns)?;
} else if Self::has_aggregates(columns) {
// Aggregation without GROUP BY
plan = self.build_aggregate_no_group(plan, columns)?;
} else {
// Regular projection
plan = self.build_projection(plan, columns)?;
}
// Apply HAVING (part of GROUP BY)
if let Some(ref group_by) = group_by {
if let Some(ref having_expr) = group_by.having {
let predicate = self.build_expr(having_expr, plan.schema())?;
plan = LogicalPlan::Filter(Filter {
input: Arc::new(plan),
predicate,
});
}
}
// Apply DISTINCT
if distinctness.is_some() {
plan = LogicalPlan::Distinct(Distinct {
input: Arc::new(plan),
});
}
Ok(plan)
}
ast::OneSelect::Values(values) => self.build_values(values),
}
}
// Build FROM clause
fn build_from(&mut self, from: &ast::FromClause) -> Result<LogicalPlan> {
let mut plan = self.build_select_table(&from.select)?;
// Handle JOINs
if !from.joins.is_empty() {
for join in &from.joins {
let right = self.build_select_table(&join.table)?;
plan = self.build_join(plan, right, &join.operator, &join.constraint)?;
}
}
Ok(plan)
}
// Build a table reference
fn build_select_table(&mut self, table: &ast::SelectTable) -> Result<LogicalPlan> {
match table {
ast::SelectTable::Table(name, alias, _indexed) => {
let table_name = Self::name_to_string(&name.name);
// Check if it's a CTE reference
if let Some(cte_plan) = self.ctes.get(&table_name) {
return Ok(LogicalPlan::CTERef(CTERef {
name: table_name.clone(),
schema: cte_plan.schema().clone(),
}));
}
// Regular table scan
let table_alias = alias.as_ref().map(|a| match a {
ast::As::As(name) => Self::name_to_string(name),
ast::As::Elided(name) => Self::name_to_string(name),
});
let table_schema = self.get_table_schema(&table_name, table_alias.as_deref())?;
Ok(LogicalPlan::TableScan(TableScan {
table_name,
alias: table_alias.clone(),
schema: table_schema,
projection: None,
}))
}
ast::SelectTable::Select(subquery, _alias) => self.build_select(subquery),
ast::SelectTable::TableCall(_, _, _) => Err(LimboError::ParseError(
"Table-valued functions are not supported in logical plans".to_string(),
)),
ast::SelectTable::Sub(_, _) => Err(LimboError::ParseError(
"Subquery in FROM clause not yet supported".to_string(),
)),
}
}
// Build JOIN
fn build_join(
&mut self,
left: LogicalPlan,
right: LogicalPlan,
op: &ast::JoinOperator,
constraint: &Option<ast::JoinConstraint>,
) -> Result<LogicalPlan> {
// Determine join type
let join_type = match op {
ast::JoinOperator::Comma => JoinType::Cross, // Comma is essentially a cross join
ast::JoinOperator::TypedJoin(Some(jt)) => {
// Check the join type flags
// Note: JoinType can have multiple flags set
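// e.g. a `LEFT OUTER JOIN` is expected to arrive here with both the LEFT and
// OUTER flags set, and a `NATURAL LEFT JOIN` additionally sets NATURAL.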
if jt.contains(ast::JoinType::NATURAL) {
// Natural joins need special handling - find common columns
return self.build_natural_join(left, right, JoinType::Inner);
} else if jt.contains(ast::JoinType::LEFT)
&& jt.contains(ast::JoinType::RIGHT)
&& jt.contains(ast::JoinType::OUTER)
{
// FULL OUTER JOIN (has LEFT, RIGHT, and OUTER)
JoinType::Full
} else if jt.contains(ast::JoinType::LEFT) && jt.contains(ast::JoinType::OUTER) {
JoinType::Left
} else if jt.contains(ast::JoinType::RIGHT) && jt.contains(ast::JoinType::OUTER) {
JoinType::Right
} else if jt.contains(ast::JoinType::OUTER)
&& !jt.contains(ast::JoinType::LEFT)
&& !jt.contains(ast::JoinType::RIGHT)
{
// Plain OUTER JOIN should also be FULL
JoinType::Full
} else if jt.contains(ast::JoinType::LEFT) {
JoinType::Left
} else if jt.contains(ast::JoinType::RIGHT) {
JoinType::Right
} else if jt.contains(ast::JoinType::CROSS) {
JoinType::Cross
} else {
JoinType::Inner // Default to inner
}
}
ast::JoinOperator::TypedJoin(None) => JoinType::Inner, // Default JOIN is INNER JOIN
};
// Build join conditions
let (on_conditions, filter) = match constraint {
Some(ast::JoinConstraint::On(expr)) => {
// Parse ON clause into equijoin conditions and filters
self.parse_join_conditions(expr, left.schema(), right.schema())?
}
Some(ast::JoinConstraint::Using(columns)) => {
// Build equijoin conditions from USING clause
let on = self.build_using_conditions(columns, left.schema(), right.schema())?;
(on, None)
}
None => {
// Cross join or natural join
(Vec::new(), None)
}
};
// Build combined schema
let schema = self.build_join_schema(&left, &right, &join_type)?;
Ok(LogicalPlan::Join(Join {
left: Arc::new(left),
right: Arc::new(right),
join_type,
on: on_conditions,
filter,
schema,
}))
}
// Helper: Parse join conditions into equijoins and filters
fn parse_join_conditions(
&mut self,
expr: &ast::Expr,
left_schema: &SchemaRef,
right_schema: &SchemaRef,
) -> Result<JoinConditionsResult> {
// For now, we'll handle simple equality conditions
// More complex conditions will go into the filter
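// e.g. `ON a.id = b.id AND a.amount > 10` yields the equijoin pair
// (a.id, b.id) plus a residual filter `a.amount > 10`.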
let mut equijoins = Vec::new();
let mut filters = Vec::new();
// Try to extract equijoin conditions from the expression
self.extract_equijoin_conditions(
expr,
left_schema,
right_schema,
&mut equijoins,
&mut filters,
)?;
let filter = if filters.is_empty() {
None
} else {
// Combine multiple filters with AND
Some(
filters
.into_iter()
.reduce(|acc, e| LogicalExpr::BinaryExpr {
left: Box::new(acc),
op: BinaryOperator::And,
right: Box::new(e),
})
.unwrap(),
)
};
Ok((equijoins, filter))
}
// Helper: Extract equijoin conditions from expression
fn extract_equijoin_conditions(
&mut self,
expr: &ast::Expr,
left_schema: &SchemaRef,
right_schema: &SchemaRef,
equijoins: &mut Vec<(LogicalExpr, LogicalExpr)>,
filters: &mut Vec<LogicalExpr>,
) -> Result<()> {
match expr {
ast::Expr::Binary(lhs, ast::Operator::Equals, rhs) => {
// Check if this is an equijoin condition (left.col = right.col)
let left_expr = self.build_expr(lhs, left_schema)?;
let right_expr = self.build_expr(rhs, right_schema)?;
// For simplicity we assume the lhs references the left input and the rhs
// references the right input; proper column resolution would verify this
equijoins.push((left_expr, right_expr));
}
ast::Expr::Binary(lhs, ast::Operator::And, rhs) => {
// Recursively extract from AND conditions
self.extract_equijoin_conditions(
lhs,
left_schema,
right_schema,
equijoins,
filters,
)?;
self.extract_equijoin_conditions(
rhs,
left_schema,
right_schema,
equijoins,
filters,
)?;
}
_ => {
// Other conditions go into the filter
// We need a combined schema to build the expression
let combined_schema = self.combine_schemas(left_schema, right_schema)?;
let filter_expr = self.build_expr(expr, &combined_schema)?;
filters.push(filter_expr);
}
}
Ok(())
}
// Helper: Build equijoin conditions from USING clause
fn build_using_conditions(
&mut self,
columns: &[ast::Name],
left_schema: &SchemaRef,
right_schema: &SchemaRef,
) -> Result<Vec<(LogicalExpr, LogicalExpr)>> {
let mut conditions = Vec::new();
for col_name in columns {
let name = Self::name_to_string(col_name);
// Find the column in both schemas
let _left_idx = left_schema
.columns
.iter()
.position(|col| col.name == name)
.ok_or_else(|| {
LimboError::ParseError(format!("Column {name} not found in left table"))
})?;
let _right_idx = right_schema
.columns
.iter()
.position(|col| col.name == name)
.ok_or_else(|| {
LimboError::ParseError(format!("Column {name} not found in right table"))
})?;
conditions.push((
LogicalExpr::Column(Column {
name: name.clone(),
table: None, // Will be resolved later
}),
LogicalExpr::Column(Column {
name,
table: None, // Will be resolved later
}),
));
}
Ok(conditions)
}
// Helper: Build natural join by finding common columns
fn build_natural_join(
&mut self,
left: LogicalPlan,
right: LogicalPlan,
join_type: JoinType,
) -> Result<LogicalPlan> {
let left_schema = left.schema();
let right_schema = right.schema();
// Find common column names
let mut common_columns = Vec::new();
for left_col in &left_schema.columns {
if right_schema
.columns
.iter()
.any(|col| col.name == left_col.name)
{
common_columns.push(ast::Name::exact(left_col.name.clone()));
}
}
if common_columns.is_empty() {
// Natural join with no common columns becomes a cross join
let schema = self.build_join_schema(&left, &right, &JoinType::Cross)?;
return Ok(LogicalPlan::Join(Join {
left: Arc::new(left),
right: Arc::new(right),
join_type: JoinType::Cross,
on: Vec::new(),
filter: None,
schema,
}));
}
// Build equijoin conditions for common columns
let on = self.build_using_conditions(&common_columns, left_schema, right_schema)?;
let schema = self.build_join_schema(&left, &right, &join_type)?;
Ok(LogicalPlan::Join(Join {
left: Arc::new(left),
right: Arc::new(right),
join_type,
on,
filter: None,
schema,
}))
}
// Helper: Build schema for join result
fn build_join_schema(
&self,
left: &LogicalPlan,
right: &LogicalPlan,
_join_type: &JoinType,
) -> Result<SchemaRef> {
let left_schema = left.schema();
let right_schema = right.schema();
// Concatenate the schemas, preserving all column information
let mut columns = Vec::new();
// Keep all columns from left with their table info
for col in &left_schema.columns {
columns.push(col.clone());
}
// Keep all columns from right with their table info
for col in &right_schema.columns {
columns.push(col.clone());
}
Ok(Arc::new(LogicalSchema::new(columns)))
}
// Helper: Combine two schemas for expression building
fn combine_schemas(&self, left: &SchemaRef, right: &SchemaRef) -> Result<SchemaRef> {
let mut columns = left.columns.clone();
columns.extend(right.columns.clone());
Ok(Arc::new(LogicalSchema::new(columns)))
}
// Build projection
fn build_projection(
&mut self,
input: LogicalPlan,
columns: &[ast::ResultColumn],
) -> Result<LogicalPlan> {
let input_schema = input.schema();
let mut proj_exprs = Vec::new();
let mut schema_columns = Vec::new();
for col in columns {
match col {
ast::ResultColumn::Expr(expr, alias) => {
let logical_expr = self.build_expr(expr, input_schema)?;
let col_name = match alias {
Some(as_alias) => match as_alias {
ast::As::As(name) | ast::As::Elided(name) => Self::name_to_string(name),
},
None => Self::expr_to_column_name(expr),
};
let col_type = Self::infer_expr_type(&logical_expr, input_schema)?;
schema_columns.push(ColumnInfo {
name: col_name.clone(),
ty: col_type,
database: None,
table: None,
table_alias: None,
});
if let Some(as_alias) = alias {
let alias_name = match as_alias {
ast::As::As(name) | ast::As::Elided(name) => Self::name_to_string(name),
};
proj_exprs.push(LogicalExpr::Alias {
expr: Box::new(logical_expr),
alias: alias_name,
});
} else {
proj_exprs.push(logical_expr);
}
}
ast::ResultColumn::Star => {
// Expand * to all columns
for col in &input_schema.columns {
proj_exprs.push(LogicalExpr::Column(Column::new(col.name.clone())));
schema_columns.push(col.clone());
}
}
ast::ResultColumn::TableStar(table) => {
// Expand table.* to all columns from that table
let table_name = Self::name_to_string(table);
for col in &input_schema.columns {
// Simple check - would need proper table tracking in real implementation
proj_exprs.push(LogicalExpr::Column(Column::with_table(
col.name.clone(),
table_name.clone(),
)));
schema_columns.push(col.clone());
}
}
}
}
Ok(LogicalPlan::Projection(Projection {
input: Arc::new(input),
exprs: proj_exprs,
schema: Arc::new(LogicalSchema::new(schema_columns)),
}))
}
// Helper function to preprocess aggregate expressions that contain complex arguments
// Returns: (needs_pre_projection, pre_projection_exprs, pre_projection_schema, modified_aggr_exprs, modified_group_exprs)
//
// This will be used in expressions like select sum(hex(a + 2)) from tbl => hex(a + 2) is a
// pre-projection.
//
// Another alternative is to always generate a projection together with an aggregation, and
// just have "a" be the identity projection if we don't have a complex case. But that's quite
// wasteful.
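//
// Sketch of the resulting plan shape (illustrative; the generated names follow
// the `__agg_arg_proj_N` / `__group_proj_N` patterns used below):
//   SELECT sum(hex(a + 2)) FROM tbl
//     Aggregate [sum(__agg_arg_proj_0)]
//       Projection [hex(a + 2) named __agg_arg_proj_0]
//         TableScan tbl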
fn preprocess_aggregate_expressions(
aggr_exprs: &[LogicalExpr],
group_exprs: &[LogicalExpr],
input_schema: &SchemaRef,
) -> Result<PreprocessAggregateResult> {
let mut needs_pre_projection = false;
let mut pre_projection_exprs = Vec::new();
let mut pre_projection_schema = Vec::new();
let mut modified_aggr_exprs = Vec::new();
let mut modified_group_exprs = Vec::new();
let mut projected_col_counter = 0;
// First, add all group by expressions to the pre-projection
for expr in group_exprs {
if let LogicalExpr::Column(col) = expr {
pre_projection_exprs.push(expr.clone());
let col_type = Self::infer_expr_type(expr, input_schema)?;
pre_projection_schema.push(ColumnInfo {
name: col.name.clone(),
ty: col_type,
database: None,
table: col.table.clone(),
table_alias: None,
});
// Column references stay as-is in the modified group expressions
modified_group_exprs.push(expr.clone());
} else {
// Complex group by expression - project it
needs_pre_projection = true;
let proj_col_name = format!("__group_proj_{projected_col_counter}");
projected_col_counter += 1;
pre_projection_exprs.push(expr.clone());
let col_type = Self::infer_expr_type(expr, input_schema)?;
pre_projection_schema.push(ColumnInfo {
name: proj_col_name.clone(),
ty: col_type,
database: None,
table: None,
table_alias: None,
});
// Replace complex expression with reference to projected column
modified_group_exprs.push(LogicalExpr::Column(Column {
name: proj_col_name,
table: None,
}));
}
}
// Check each aggregate expression
for agg_expr in aggr_exprs {
if let LogicalExpr::AggregateFunction {
fun,
args,
distinct,
} = agg_expr
{
let mut modified_args = Vec::new();
for arg in args {
// Check if the argument is a simple column reference or a complex expression
match arg {
LogicalExpr::Column(_) => {
// Simple column - just use it
modified_args.push(arg.clone());
// Make sure the column is in the pre-projection
if !pre_projection_exprs.iter().any(|e| e == arg) {
pre_projection_exprs.push(arg.clone());
let col_type = Self::infer_expr_type(arg, input_schema)?;
if let LogicalExpr::Column(col) = arg {
pre_projection_schema.push(ColumnInfo {
name: col.name.clone(),
ty: col_type,
database: None,
table: col.table.clone(),
table_alias: None,
});
}
}
}
_ => {
// Complex expression - we need to project it first
needs_pre_projection = true;
let proj_col_name = format!("__agg_arg_proj_{projected_col_counter}");
projected_col_counter += 1;
// Add the expression to the pre-projection
pre_projection_exprs.push(arg.clone());
let col_type = Self::infer_expr_type(arg, input_schema)?;
pre_projection_schema.push(ColumnInfo {
name: proj_col_name.clone(),
ty: col_type,
database: None,
table: None,
table_alias: None,
});
// In the aggregate, reference the projected column
modified_args.push(LogicalExpr::Column(Column::new(proj_col_name)));
}
}
}
// Create the modified aggregate expression
modified_aggr_exprs.push(LogicalExpr::AggregateFunction {
fun: fun.clone(),
args: modified_args,
distinct: *distinct,
});
} else {
modified_aggr_exprs.push(agg_expr.clone());
}
}
Ok((
needs_pre_projection,
pre_projection_exprs,
pre_projection_schema,
modified_aggr_exprs,
modified_group_exprs,
))
}
// Build aggregate with GROUP BY
fn build_aggregate(
&mut self,
input: LogicalPlan,
group_by: &ast::GroupBy,
columns: &[ast::ResultColumn],
) -> Result<LogicalPlan> {
let input_schema = input.schema();
// Build grouping expressions
let mut group_exprs = Vec::new();
for expr in &group_by.exprs {
group_exprs.push(self.build_expr(expr, input_schema)?);
}
// Use the unified aggregate builder
self.build_aggregate_internal(input, group_exprs, columns)
}
// Build aggregate without GROUP BY
fn build_aggregate_no_group(
&mut self,
input: LogicalPlan,
columns: &[ast::ResultColumn],
) -> Result<LogicalPlan> {
// Use the unified aggregate builder with empty group expressions
self.build_aggregate_internal(input, vec![], columns)
}
// Unified internal aggregate builder that handles both GROUP BY and non-GROUP BY cases
fn build_aggregate_internal(
&mut self,
input: LogicalPlan,
group_exprs: Vec<LogicalExpr>,
columns: &[ast::ResultColumn],
) -> Result<LogicalPlan> {
let input_schema = input.schema();
let has_group_by = !group_exprs.is_empty();
// First pass: build a map of aliases to expressions from the SELECT list
// and a vector of SELECT expressions for positional references
// This allows GROUP BY to reference SELECT aliases (e.g., GROUP BY year)
// or positions (e.g., GROUP BY 1)
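// e.g. `SELECT substr(d, 1, 4) AS year, count(*) FROM t GROUP BY year` and
// `... GROUP BY 1` both end up grouping on `substr(d, 1, 4)`.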
let mut alias_to_expr = HashMap::new();
let mut select_exprs = Vec::new();
for col in columns {
if let ast::ResultColumn::Expr(expr, alias) = col {
let logical_expr = self.build_expr(expr, input_schema)?;
select_exprs.push(logical_expr.clone());
if let Some(alias) = alias {
let alias_name = match alias {
ast::As::As(name) | ast::As::Elided(name) => Self::name_to_string(name),
};
alias_to_expr.insert(alias_name, logical_expr);
}
}
}
// Resolve GROUP BY expressions: replace column references that match SELECT aliases
// or integer literals that represent positions
let group_exprs = group_exprs
.into_iter()
.map(|expr| {
// Check for positional reference (integer literal)
if let LogicalExpr::Literal(crate::types::Value::Integer(pos)) = &expr {
// SQLite uses 1-based indexing
if *pos > 0 && (*pos as usize) <= select_exprs.len() {
return select_exprs[(*pos as usize) - 1].clone();
}
}
// Check for alias reference (unqualified column name)
if let LogicalExpr::Column(col) = &expr {
if col.table.is_none() {
// Unqualified column - check if it matches an alias
if let Some(aliased_expr) = alias_to_expr.get(&col.name) {
return aliased_expr.clone();
}
}
}
expr
})
.collect::<Vec<_>>();
// Build aggregate expressions and projection expressions
let mut aggr_exprs = Vec::new();
let mut projection_exprs = Vec::new();
let mut aggregate_schema_columns = Vec::new();
// First, add GROUP BY columns to the aggregate output schema
// These are always part of the aggregate operator's output
for group_expr in &group_exprs {
match group_expr {
LogicalExpr::Column(col) => {
// For column references in GROUP BY, preserve the original column info
if let Some((_, col_info)) =
input_schema.find_column(&col.name, col.table.as_deref())
{
// Preserve the column with all its table information
aggregate_schema_columns.push(col_info.clone());
} else {
// Fallback if column not found (shouldn't happen)
let col_type = Self::infer_expr_type(group_expr, input_schema)?;
aggregate_schema_columns.push(ColumnInfo {
name: col.name.clone(),
ty: col_type,
database: None,
table: col.table.clone(),
table_alias: None,
});
}
}
_ => {
// For complex GROUP BY expressions, generate a name
let col_name = format!("__group_{}", aggregate_schema_columns.len());
let col_type = Self::infer_expr_type(group_expr, input_schema)?;
aggregate_schema_columns.push(ColumnInfo {
name: col_name,
ty: col_type,
database: None,
table: None,
table_alias: None,
});
}
}
}
// Track aggregates we've already seen to avoid duplicates
let mut aggregate_map: HashMap<String, String> = HashMap::new();
for col in columns {
match col {
ast::ResultColumn::Expr(expr, alias) => {
let logical_expr = self.build_expr(expr, input_schema)?;
// Determine the column name for this expression
let col_name = match alias {
Some(as_alias) => match as_alias {
ast::As::As(name) | ast::As::Elided(name) => Self::name_to_string(name),
},
None => Self::expr_to_column_name(expr),
};
// Check if the TOP-LEVEL expression is an aggregate
// We only care about immediate aggregates, not nested ones
if Self::is_aggregate_expr(&logical_expr) {
// Pure aggregate function - check if we've seen it before
let agg_key = format!("{logical_expr:?}");
let agg_col_name = if let Some(existing_name) = aggregate_map.get(&agg_key)
{
// Reuse existing aggregate
existing_name.clone()
} else {
// New aggregate - add it
let col_type = Self::infer_expr_type(&logical_expr, input_schema)?;
aggregate_schema_columns.push(ColumnInfo {
name: col_name.clone(),
ty: col_type,
database: None,
table: None,
table_alias: None,
});
aggr_exprs.push(logical_expr);
aggregate_map.insert(agg_key, col_name.clone());
col_name.clone()
};
// In the projection, reference this aggregate by name
projection_exprs.push(LogicalExpr::Column(Column {
name: agg_col_name,
table: None,
}));
} else if Self::contains_aggregate(&logical_expr) {
// This is an expression that contains an aggregate somewhere
// (e.g., sum(a + 2) * 2)
// We need to extract aggregates and replace them with column references
let (processed_expr, extracted_aggs) =
Self::extract_and_replace_aggregates_with_dedup(
logical_expr,
&mut aggregate_map,
)?;
// Add only new aggregates
for (agg_expr, agg_name) in extracted_aggs {
let agg_type = Self::infer_expr_type(&agg_expr, input_schema)?;
aggregate_schema_columns.push(ColumnInfo {
name: agg_name,
ty: agg_type,
database: None,
table: None,
table_alias: None,
});
aggr_exprs.push(agg_expr);
}
// Add the processed expression (with column refs) to projection
projection_exprs.push(processed_expr);
} else {
// Non-aggregate expression - validation depends on GROUP BY presence
if has_group_by {
// With GROUP BY: only allow constants and grouped columns
// TODO: SQLite actually allows any column here and returns the value from
// the first row encountered in each group. We should support this in the
// future for full SQLite compatibility, but for now we're stricter to
// simplify the DBSP compilation.
if !Self::is_constant_expr(&logical_expr)
&& !Self::is_valid_in_group_by(&logical_expr, &group_exprs)
{
return Err(LimboError::ParseError(format!(
"Column '{col_name}' must appear in the GROUP BY clause or be used in an aggregate function"
)));
}
// If this expression matches a GROUP BY expression, replace it with a reference
// to the corresponding column in the aggregate output
let logical_expr_stripped = strip_alias(&logical_expr);
if let Some(group_idx) = group_exprs
.iter()
.position(|g| logical_expr_stripped == strip_alias(g))
{
// Reference the GROUP BY column in the aggregate output by its name
let group_col_name = &aggregate_schema_columns[group_idx].name;
projection_exprs.push(LogicalExpr::Column(Column {
name: group_col_name.clone(),
table: None,
}));
} else {
projection_exprs.push(logical_expr);
}
} else {
// Without GROUP BY: only allow constant expressions
// TODO: SQLite allows any column here and returns a value from an
// arbitrary row. We should support this for full compatibility,
// but for now we're stricter to simplify DBSP compilation.
if !Self::is_constant_expr(&logical_expr) {
return Err(LimboError::ParseError(format!(
"Column '{col_name}' must be used in an aggregate function when using aggregates without GROUP BY"
)));
}
projection_exprs.push(logical_expr);
}
}
}
_ => {
let error_msg = if has_group_by {
"* not supported with GROUP BY".to_string()
} else {
"* not supported with aggregate functions".to_string()
};
return Err(LimboError::ParseError(error_msg));
}
}
}
// Check if any aggregate functions have complex expressions as arguments
// or if GROUP BY has complex expressions
// If so, we need to insert a projection before the aggregate
let (
needs_pre_projection,
pre_projection_exprs,
pre_projection_schema,
modified_aggr_exprs,
modified_group_exprs,
) = Self::preprocess_aggregate_expressions(&aggr_exprs, &group_exprs, input_schema)?;
// Build the final schema for the projection
let mut projection_schema_columns = Vec::new();
for (i, expr) in projection_exprs.iter().enumerate() {
let col_name = if i < columns.len() {
match &columns[i] {
ast::ResultColumn::Expr(e, alias) => match alias {
Some(as_alias) => match as_alias {
ast::As::As(name) | ast::As::Elided(name) => Self::name_to_string(name),
},
None => Self::expr_to_column_name(e),
},
_ => format!("col_{i}"),
}
} else {
format!("col_{i}")
};
// For type inference, we need the aggregate schema for column references
let aggregate_schema = LogicalSchema::new(aggregate_schema_columns.clone());
let col_type = Self::infer_expr_type(expr, &Arc::new(aggregate_schema))?;
projection_schema_columns.push(ColumnInfo {
name: col_name,
ty: col_type,
database: None,
table: None,
table_alias: None,
});
}
// Create the input plan (with pre-projection if needed)
let aggregate_input = if needs_pre_projection {
Arc::new(LogicalPlan::Projection(Projection {
input: Arc::new(input),
exprs: pre_projection_exprs,
schema: Arc::new(LogicalSchema::new(pre_projection_schema)),
}))
} else {
Arc::new(input)
};
// Use modified aggregate and group expressions if we inserted a pre-projection
let final_aggr_exprs = if needs_pre_projection {
modified_aggr_exprs
} else {
aggr_exprs
};
let final_group_exprs = if needs_pre_projection {
modified_group_exprs
} else {
group_exprs
};
// Check if we need the outer projection
// We need a projection if:
// 1. We have expressions that compute new values (e.g., SUM(x) * 2)
// 2. We're selecting a different set of columns than GROUP BY + aggregates
// 3. We're reordering columns from their natural aggregate output order
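// e.g. `SELECT sum(x) * 2 FROM t GROUP BY y` needs the outer projection, while
// `SELECT y, sum(x) FROM t GROUP BY y` matches the aggregate output and does not.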
let needs_outer_projection = {
// Check for complex expressions
let has_complex_exprs = projection_exprs
.iter()
.any(|expr| !matches!(expr, LogicalExpr::Column(_)));
if has_complex_exprs {
true
} else {
// Check if we're selecting exactly what aggregate outputs in the same order
// The aggregate outputs: all GROUP BY columns, then all aggregate expressions
// The projection might select a subset or reorder these
if projection_exprs.len() != aggregate_schema_columns.len() {
// Different number of columns
true
} else {
// Check if columns match in order and name
!projection_exprs.iter().zip(&aggregate_schema_columns).all(
|(expr, agg_col)| {
if let LogicalExpr::Column(col) = expr {
col.name == agg_col.name
} else {
false
}
},
)
}
}
};
// Create the aggregate node with its natural schema
let aggregate_plan = LogicalPlan::Aggregate(Aggregate {
input: aggregate_input,
group_expr: final_group_exprs,
aggr_expr: final_aggr_exprs,
schema: Arc::new(LogicalSchema::new(aggregate_schema_columns)),
});
if needs_outer_projection {
Ok(LogicalPlan::Projection(Projection {
input: Arc::new(aggregate_plan),
exprs: projection_exprs,
schema: Arc::new(LogicalSchema::new(projection_schema_columns)),
}))
} else {
// No projection needed - aggregate output matches what we want
Ok(aggregate_plan)
}
}
/// Build VALUES clause
#[allow(clippy::vec_box)]
fn build_values(&mut self, values: &[Vec<Box<ast::Expr>>]) -> Result<LogicalPlan> {
if values.is_empty() {
return Err(LimboError::ParseError("Empty VALUES clause".to_string()));
}
let mut rows = Vec::new();
let first_row_len = values[0].len();
// Derive the column count and placeholder names from the first row; column types default to Text
let mut schema_columns = Vec::new();
for (i, _) in values[0].iter().enumerate() {
schema_columns.push(ColumnInfo {
name: format!("column{}", i + 1),
ty: Type::Text,
database: None,
table: None,
table_alias: None,
});
}
for row in values {
if row.len() != first_row_len {
return Err(LimboError::ParseError(
"All rows in VALUES must have the same number of columns".to_string(),
));
}
let mut logical_row = Vec::new();
for expr in row {
// VALUES doesn't have input schema
let empty_schema = Arc::new(LogicalSchema::empty());
logical_row.push(self.build_expr(expr, &empty_schema)?);
}
rows.push(logical_row);
}
Ok(LogicalPlan::Values(Values {
rows,
schema: Arc::new(LogicalSchema::new(schema_columns)),
}))
}
// Build SORT
fn build_sort(
&mut self,
input: LogicalPlan,
exprs: &[ast::SortedColumn],
) -> Result<LogicalPlan> {
let input_schema = input.schema();
let mut sort_exprs = Vec::new();
for sorted_col in exprs {
let expr = self.build_expr(&sorted_col.expr, input_schema)?;
sort_exprs.push(SortExpr {
expr,
asc: sorted_col.order != Some(ast::SortOrder::Desc),
nulls_first: sorted_col.nulls == Some(ast::NullsOrder::First),
});
}
Ok(LogicalPlan::Sort(Sort {
input: Arc::new(input),
exprs: sort_exprs,
}))
}
// Build LIMIT
fn build_limit(input: LogicalPlan, limit: &ast::Limit) -> Result<LogicalPlan> {
let fetch = match limit.expr.as_ref() {
ast::Expr::Literal(ast::Literal::Numeric(s)) => s.parse::<usize>().ok(),
_ => {
return Err(LimboError::ParseError(
"LIMIT must be a literal integer".to_string(),
))
}
};
let skip = if let Some(offset) = &limit.offset {
match offset.as_ref() {
ast::Expr::Literal(ast::Literal::Numeric(s)) => s.parse::<usize>().ok(),
_ => {
return Err(LimboError::ParseError(
"OFFSET must be a literal integer".to_string(),
))
}
}
} else {
None
};
Ok(LogicalPlan::Limit(Limit {
input: Arc::new(input),
skip,
fetch,
}))
}
// Build compound operator (UNION, INTERSECT, EXCEPT)
fn build_compound(
left: LogicalPlan,
right: LogicalPlan,
op: &ast::CompoundOperator,
) -> Result<LogicalPlan> {
// Check schema compatibility
if left.schema().column_count() != right.schema().column_count() {
return Err(LimboError::ParseError(
"UNION/INTERSECT/EXCEPT requires same number of columns".to_string(),
));
}
let all = matches!(op, ast::CompoundOperator::UnionAll);
match op {
ast::CompoundOperator::Union | ast::CompoundOperator::UnionAll => {
let schema = left.schema().clone();
Ok(LogicalPlan::Union(Union {
inputs: vec![Arc::new(left), Arc::new(right)],
all,
schema,
}))
}
_ => Err(LimboError::ParseError(
"INTERSECT and EXCEPT not yet supported in logical plans".to_string(),
)),
}
}
// Build expression from AST
fn build_expr(&mut self, expr: &ast::Expr, _schema: &SchemaRef) -> Result<LogicalExpr> {
match expr {
ast::Expr::Id(name) => Ok(LogicalExpr::Column(Column::new(Self::name_to_string(name)))),
ast::Expr::DoublyQualified(db, table, col) => {
Ok(LogicalExpr::Column(Column::with_table(
Self::name_to_string(col),
format!(
"{}.{}",
Self::name_to_string(db),
Self::name_to_string(table)
),
)))
}
ast::Expr::Qualified(table, col) => Ok(LogicalExpr::Column(Column::with_table(
Self::name_to_string(col),
Self::name_to_string(table),
))),
ast::Expr::Literal(lit) => Ok(LogicalExpr::Literal(Self::build_literal(lit)?)),
ast::Expr::Binary(lhs, op, rhs) => {
// Special case: IS NULL and IS NOT NULL
if matches!(op, ast::Operator::Is | ast::Operator::IsNot) {
if let ast::Expr::Literal(ast::Literal::Null) = rhs.as_ref() {
let expr = Box::new(self.build_expr(lhs, _schema)?);
return Ok(LogicalExpr::IsNull {
expr,
negated: matches!(op, ast::Operator::IsNot),
});
}
}
let left = Box::new(self.build_expr(lhs, _schema)?);
let right = Box::new(self.build_expr(rhs, _schema)?);
Ok(LogicalExpr::BinaryExpr {
left,
op: *op,
right,
})
}
ast::Expr::Unary(op, expr) => {
let inner = Box::new(self.build_expr(expr, _schema)?);
Ok(LogicalExpr::UnaryExpr {
op: *op,
expr: inner,
})
}
ast::Expr::FunctionCall {
name,
distinctness,
args,
filter_over,
..
} => {
// Check for window functions (OVER clause)
if filter_over.over_clause.is_some() {
return Err(LimboError::ParseError(
"Unsupported expression type: window functions are not yet supported"
.to_string(),
));
}
let func_name = Self::name_to_string(name);
let arg_count = args.len();
// Check if it's an aggregate function (considering argument count for min/max)
if let Some(agg_fun) = Self::parse_aggregate_function(&func_name, arg_count) {
let distinct = distinctness.is_some();
let arg_exprs = args
.iter()
.map(|e| self.build_expr(e, _schema))
.collect::<Result<Vec<_>>>()?;
Ok(LogicalExpr::AggregateFunction {
fun: agg_fun,
args: arg_exprs,
distinct,
})
} else {
// Regular scalar function
let arg_exprs = args
.iter()
.map(|e| self.build_expr(e, _schema))
.collect::<Result<Vec<_>>>()?;
Ok(LogicalExpr::ScalarFunction {
fun: func_name,
args: arg_exprs,
})
}
}
ast::Expr::FunctionCallStar { name, .. } => {
// Handle COUNT(*) and similar
let func_name = Self::name_to_string(name);
// FunctionCallStar always has 0 args (it's the * form)
if let Some(agg_fun) = Self::parse_aggregate_function(&func_name, 0) {
Ok(LogicalExpr::AggregateFunction {
fun: agg_fun,
args: vec![],
distinct: false,
})
} else {
Err(LimboError::ParseError(format!(
"Function {func_name}(*) is not supported"
)))
}
}
ast::Expr::Case {
base,
when_then_pairs,
else_expr,
} => {
let case_expr = if let Some(e) = base {
Some(Box::new(self.build_expr(e, _schema)?))
} else {
None
};
let when_then_exprs = when_then_pairs
.iter()
.map(|(when, then)| {
Ok((
self.build_expr(when, _schema)?,
self.build_expr(then, _schema)?,
))
})
.collect::<Result<Vec<_>>>()?;
let else_result = if let Some(e) = else_expr {
Some(Box::new(self.build_expr(e, _schema)?))
} else {
None
};
Ok(LogicalExpr::Case {
expr: case_expr,
when_then: when_then_exprs,
else_expr: else_result,
})
}
ast::Expr::InList { lhs, not, rhs } => {
let expr = Box::new(self.build_expr(lhs, _schema)?);
let list = rhs
.iter()
.map(|e| self.build_expr(e, _schema))
.collect::<Result<Vec<_>>>()?;
Ok(LogicalExpr::InList {
expr,
list,
negated: *not,
})
}
ast::Expr::InSelect { lhs, not, rhs } => {
let expr = Box::new(self.build_expr(lhs, _schema)?);
let subquery = Arc::new(self.build_select(rhs)?);
Ok(LogicalExpr::InSubquery {
expr,
subquery,
negated: *not,
})
}
ast::Expr::Exists(select) => {
let subquery = Arc::new(self.build_select(select)?);
Ok(LogicalExpr::Exists {
subquery,
negated: false,
})
}
ast::Expr::Subquery(select) => {
let subquery = Arc::new(self.build_select(select)?);
Ok(LogicalExpr::ScalarSubquery(subquery))
}
ast::Expr::IsNull(lhs) => {
let expr = Box::new(self.build_expr(lhs, _schema)?);
Ok(LogicalExpr::IsNull {
expr,
negated: false,
})
}
ast::Expr::NotNull(lhs) => {
let expr = Box::new(self.build_expr(lhs, _schema)?);
Ok(LogicalExpr::IsNull {
expr,
negated: true,
})
}
ast::Expr::Between {
lhs,
not,
start,
end,
} => {
let expr = Box::new(self.build_expr(lhs, _schema)?);
let low = Box::new(self.build_expr(start, _schema)?);
let high = Box::new(self.build_expr(end, _schema)?);
Ok(LogicalExpr::Between {
expr,
low,
high,
negated: *not,
})
}
ast::Expr::Like {
lhs,
not,
op: _,
rhs,
escape,
} => {
let expr = Box::new(self.build_expr(lhs, _schema)?);
let pattern = Box::new(self.build_expr(rhs, _schema)?);
let escape_char = escape.as_ref().and_then(|e| {
if let ast::Expr::Literal(ast::Literal::String(s)) = e.as_ref() {
s.chars().next()
} else {
None
}
});
Ok(LogicalExpr::Like {
expr,
pattern,
escape: escape_char,
negated: *not,
})
}
ast::Expr::Parenthesized(exprs) => {
// The parser should guarantee at least one expression inside the parentheses.
assert!(!exprs.is_empty());
// Multiple expressions in parentheses are unusual; handle them by building
// only the first one (matching SQLite's behavior)
self.build_expr(&exprs[0], _schema)
}
ast::Expr::Cast { expr, type_name } => {
let inner = self.build_expr(expr, _schema)?;
Ok(LogicalExpr::Cast {
expr: Box::new(inner),
type_name: type_name.clone(),
})
}
_ => Err(LimboError::ParseError(format!(
"Unsupported expression type in logical plan: {expr:?}"
))),
}
}
/// Build literal value
fn build_literal(lit: &ast::Literal) -> Result<Value> {
match lit {
ast::Literal::Null => Ok(Value::Null),
ast::Literal::Keyword(k) => {
let k_bytes = k.as_bytes();
match_ignore_ascii_case!(match k_bytes {
b"true" => Ok(Value::Integer(1)), // SQLite uses int for bool
b"false" => Ok(Value::Integer(0)), // SQLite uses int for bool
_ => Ok(Value::Text(k.clone().into())),
})
}
ast::Literal::Numeric(s) => {
if let Ok(i) = s.parse::<i64>() {
Ok(Value::Integer(i))
} else if let Ok(f) = s.parse::<f64>() {
Ok(Value::Float(f))
} else {
Ok(Value::Text(s.clone().into()))
}
}
ast::Literal::String(s) => {
// Strip surrounding quotes from the SQL literal
// The parser includes quotes in the string value
let unquoted = if s.starts_with('\'') && s.ends_with('\'') && s.len() > 1 {
&s[1..s.len() - 1]
} else {
s.as_str()
};
Ok(Value::Text(unquoted.to_string().into()))
}
ast::Literal::Blob(b) => Ok(Value::Blob(b.clone().into())),
ast::Literal::CurrentDate
| ast::Literal::CurrentTime
| ast::Literal::CurrentTimestamp => Err(LimboError::ParseError(
"Temporal literals not yet supported".to_string(),
)),
}
}
/// Parse aggregate function name (considering argument count for min/max)
fn parse_aggregate_function(name: &str, arg_count: usize) -> Option<AggregateFunction> {
let name_bytes = name.as_bytes();
match_ignore_ascii_case!(match name_bytes {
b"COUNT" => Some(AggFunc::Count),
b"SUM" => Some(AggFunc::Sum),
b"AVG" => Some(AggFunc::Avg),
// MIN and MAX are aggregates only when called with a single argument;
// with 2+ arguments they are scalar functions
b"MIN" if arg_count == 1 => Some(AggFunc::Min),
b"MAX" if arg_count == 1 => Some(AggFunc::Max),
b"GROUP_CONCAT" => Some(AggFunc::GroupConcat),
b"STRING_AGG" => Some(AggFunc::StringAgg),
b"TOTAL" => Some(AggFunc::Total),
_ => None,
})
}
// Check if expression contains aggregates
fn has_aggregates(columns: &[ast::ResultColumn]) -> bool {
for col in columns {
if let ast::ResultColumn::Expr(expr, _) = col {
if Self::expr_has_aggregate(expr) {
return true;
}
}
}
false
}
// Check if AST expression contains aggregates
fn expr_has_aggregate(expr: &ast::Expr) -> bool {
match expr {
ast::Expr::FunctionCall { name, args, .. } => {
// Check if the function itself is an aggregate (considering arg count for min/max)
let arg_count = args.len();
if Self::parse_aggregate_function(&Self::name_to_string(name), arg_count).is_some()
{
return true;
}
// Also check if any arguments contain aggregates (for nested functions like HEX(SUM(...)))
args.iter().any(|arg| Self::expr_has_aggregate(arg))
}
ast::Expr::FunctionCallStar { name, .. } => {
// FunctionCallStar always has 0 args
Self::parse_aggregate_function(&Self::name_to_string(name), 0).is_some()
}
ast::Expr::Binary(lhs, _, rhs) => {
Self::expr_has_aggregate(lhs) || Self::expr_has_aggregate(rhs)
}
ast::Expr::Unary(_, e) => Self::expr_has_aggregate(e),
ast::Expr::Case {
when_then_pairs,
else_expr,
..
} => {
when_then_pairs
.iter()
.any(|(w, t)| Self::expr_has_aggregate(w) || Self::expr_has_aggregate(t))
|| else_expr
.as_ref()
.is_some_and(|e| Self::expr_has_aggregate(e))
}
ast::Expr::Parenthesized(exprs) => {
// Check if any parenthesized expression contains an aggregate
exprs.iter().any(|e| Self::expr_has_aggregate(e))
}
_ => false,
}
}
// Check if logical expression is an aggregate
fn is_aggregate_expr(expr: &LogicalExpr) -> bool {
match expr {
LogicalExpr::AggregateFunction { .. } => true,
LogicalExpr::Alias { expr, .. } => Self::is_aggregate_expr(expr),
_ => false,
}
}
// Check if logical expression contains an aggregate anywhere
fn contains_aggregate(expr: &LogicalExpr) -> bool {
match expr {
LogicalExpr::AggregateFunction { .. } => true,
LogicalExpr::Alias { expr, .. } => Self::contains_aggregate(expr),
LogicalExpr::BinaryExpr { left, right, .. } => {
Self::contains_aggregate(left) || Self::contains_aggregate(right)
}
LogicalExpr::UnaryExpr { expr, .. } => Self::contains_aggregate(expr),
LogicalExpr::ScalarFunction { args, .. } => args.iter().any(Self::contains_aggregate),
LogicalExpr::Case {
when_then,
else_expr,
..
} => {
when_then
.iter()
.any(|(w, t)| Self::contains_aggregate(w) || Self::contains_aggregate(t))
|| else_expr
.as_ref()
.is_some_and(|e| Self::contains_aggregate(e))
}
_ => false,
}
}
// Check if an expression is a constant (contains only literals)
fn is_constant_expr(expr: &LogicalExpr) -> bool {
match expr {
LogicalExpr::Literal(_) => true,
LogicalExpr::BinaryExpr { left, right, .. } => {
Self::is_constant_expr(left) && Self::is_constant_expr(right)
}
LogicalExpr::UnaryExpr { expr, .. } => Self::is_constant_expr(expr),
LogicalExpr::ScalarFunction { args, .. } => args.iter().all(Self::is_constant_expr),
LogicalExpr::Alias { expr, .. } => Self::is_constant_expr(expr),
_ => false,
}
}
// Check if an expression is valid in GROUP BY context
// An expression is valid if it's:
// 1. A constant literal
// 2. An aggregate function
// 3. A grouping column (or expression involving only grouping columns)
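//
// e.g. with `GROUP BY a`, the expressions `a`, `a + 1`, `count(*)`, and `42`
// are all accepted, while a bare reference to `b` is rejected.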
fn is_valid_in_group_by(expr: &LogicalExpr, group_exprs: &[LogicalExpr]) -> bool {
// First check if the entire expression appears in GROUP BY
// Strip aliases before comparing since SELECT might have aliases but GROUP BY might not
let expr_stripped = strip_alias(expr);
if group_exprs.iter().any(|g| expr_stripped == strip_alias(g)) {
return true;
}
// If not, check recursively based on expression type
match expr {
LogicalExpr::Literal(_) => true, // Constants are always valid
LogicalExpr::AggregateFunction { .. } => true, // Aggregates are valid
LogicalExpr::Column(col) => {
// Check if this column is in the GROUP BY
group_exprs.iter().any(|g| match g {
LogicalExpr::Column(gcol) => gcol.name == col.name,
_ => false,
})
}
LogicalExpr::BinaryExpr { left, right, .. } => {
// Both sides must be valid
Self::is_valid_in_group_by(left, group_exprs)
&& Self::is_valid_in_group_by(right, group_exprs)
}
LogicalExpr::UnaryExpr { expr, .. } => Self::is_valid_in_group_by(expr, group_exprs),
LogicalExpr::ScalarFunction { args, .. } => {
// All arguments must be valid
args.iter()
.all(|arg| Self::is_valid_in_group_by(arg, group_exprs))
}
LogicalExpr::Alias { expr, .. } => Self::is_valid_in_group_by(expr, group_exprs),
_ => false, // Other expressions are not valid
}
}
// Extract aggregates from an expression and replace them with column references, with deduplication
// Returns the modified expression and a list of NEW (aggregate_expr, column_name) pairs
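//
// e.g. `sum(a) + sum(a) + max(b)` becomes `__agg_0 + __agg_0 + __agg_1` (names
// offset by any aggregates already registered), with only the two distinct
// aggregates returned as new entries.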
fn extract_and_replace_aggregates_with_dedup(
expr: LogicalExpr,
aggregate_map: &mut HashMap<String, String>,
) -> Result<(LogicalExpr, Vec<(LogicalExpr, String)>)> {
let mut new_aggregates = Vec::new();
let mut counter = aggregate_map.len();
let new_expr = Self::replace_aggregates_with_columns_dedup(
expr,
&mut new_aggregates,
aggregate_map,
&mut counter,
)?;
Ok((new_expr, new_aggregates))
}
// Recursively replace aggregate functions with column references, with deduplication
fn replace_aggregates_with_columns_dedup(
expr: LogicalExpr,
new_aggregates: &mut Vec<(LogicalExpr, String)>,
aggregate_map: &mut HashMap<String, String>,
counter: &mut usize,
) -> Result<LogicalExpr> {
match expr {
LogicalExpr::AggregateFunction { .. } => {
// Found an aggregate - check if we've seen it before
let agg_key = format!("{expr:?}");
let col_name = if let Some(existing_name) = aggregate_map.get(&agg_key) {
// Reuse existing aggregate
existing_name.clone()
} else {
// New aggregate
let col_name = format!("__agg_{}", *counter);
*counter += 1;
aggregate_map.insert(agg_key, col_name.clone());
new_aggregates.push((expr, col_name.clone()));
col_name
};
Ok(LogicalExpr::Column(Column {
name: col_name,
table: None,
}))
}
LogicalExpr::BinaryExpr { left, op, right } => {
let new_left = Self::replace_aggregates_with_columns_dedup(
*left,
new_aggregates,
aggregate_map,
counter,
)?;
let new_right = Self::replace_aggregates_with_columns_dedup(
*right,
new_aggregates,
aggregate_map,
counter,
)?;
Ok(LogicalExpr::BinaryExpr {
left: Box::new(new_left),
op,
right: Box::new(new_right),
})
}
LogicalExpr::UnaryExpr { op, expr } => {
let new_expr = Self::replace_aggregates_with_columns_dedup(
*expr,
new_aggregates,
aggregate_map,
counter,
)?;
Ok(LogicalExpr::UnaryExpr {
op,
expr: Box::new(new_expr),
})
}
LogicalExpr::ScalarFunction { fun, args } => {
let mut new_args = Vec::new();
for arg in args {
new_args.push(Self::replace_aggregates_with_columns_dedup(
arg,
new_aggregates,
aggregate_map,
counter,
)?);
}
Ok(LogicalExpr::ScalarFunction {
fun,
args: new_args,
})
}
LogicalExpr::Case {
expr: case_expr,
when_then,
else_expr,
} => {
let new_case_expr = if let Some(e) = case_expr {
Some(Box::new(Self::replace_aggregates_with_columns_dedup(
*e,
new_aggregates,
aggregate_map,
counter,
)?))
} else {
None
};
let mut new_when_then = Vec::new();
for (when, then) in when_then {
let new_when = Self::replace_aggregates_with_columns_dedup(
when,
new_aggregates,
aggregate_map,
counter,
)?;
let new_then = Self::replace_aggregates_with_columns_dedup(
then,
new_aggregates,
aggregate_map,
counter,
)?;
new_when_then.push((new_when, new_then));
}
let new_else = if let Some(e) = else_expr {
Some(Box::new(Self::replace_aggregates_with_columns_dedup(
*e,
new_aggregates,
aggregate_map,
counter,
)?))
} else {
None
};
Ok(LogicalExpr::Case {
expr: new_case_expr,
when_then: new_when_then,
else_expr: new_else,
})
}
LogicalExpr::Alias { expr, alias } => {
let new_expr = Self::replace_aggregates_with_columns_dedup(
*expr,
new_aggregates,
aggregate_map,
counter,
)?;
Ok(LogicalExpr::Alias {
expr: Box::new(new_expr),
alias,
})
}
// Other expressions - keep as is
_ => Ok(expr),
}
}
// Get column name from expression
fn expr_to_column_name(expr: &ast::Expr) -> String {
match expr {
ast::Expr::Id(name) => Self::name_to_string(name),
ast::Expr::Qualified(_, col) => Self::name_to_string(col),
ast::Expr::FunctionCall { name, .. } => Self::name_to_string(name),
ast::Expr::FunctionCallStar { name, .. } => {
format!("{}(*)", Self::name_to_string(name))
}
_ => "expr".to_string(),
}
}
    /// Build the logical schema for a base table, tagging each column with its
    /// database, table name, and optional alias.
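    ///
    /// Illustrative sketch (hypothetical names): for a reference to `"db1.orders"` with
    /// alias `o`, each returned column would carry `database = Some("db1")`,
    /// `table = Some("orders")`, and `table_alias = Some("o")`, assuming the schema
    /// lookup resolves the qualified name.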
fn get_table_schema(&self, table_name: &str, alias: Option<&str>) -> Result<SchemaRef> {
// Look up table in schema
let table = self
.schema
.get_table(table_name)
.ok_or_else(|| LimboError::ParseError(format!("Table '{table_name}' not found")))?;
// Parse table_name which might be "db.table" for attached databases
let (database, actual_table) = if table_name.contains('.') {
let parts: Vec<&str> = table_name.splitn(2, '.').collect();
(Some(parts[0].to_string()), parts[1].to_string())
} else {
(None, table_name.to_string())
};
let mut columns = Vec::new();
for col in table.columns() {
if let Some(ref name) = col.name {
columns.push(ColumnInfo {
name: name.clone(),
ty: col.ty,
database: database.clone(),
table: Some(actual_table.clone()),
table_alias: alias.map(|s| s.to_string()),
});
}
}
Ok(Arc::new(LogicalSchema::new(columns)))
}
    /// Infer the result type of a logical expression against the given schema.
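    ///
    /// A few illustrative cases (a sketch, using the column types from the test schema
    /// later in this file: `users.age` is INTEGER, `orders.amount` is REAL):
    ///
    /// ```text
    /// age + 1          -> Integer   (Integer op Integer stays Integer)
    /// amount * 2       -> Real      (a Real operand propagates)
    /// age / 2          -> Real      (division is typed conservatively as Real here)
    /// name LIKE 'a%'   -> Integer   (comparison-like results are Integer)
    /// ```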
fn infer_expr_type(expr: &LogicalExpr, schema: &SchemaRef) -> Result<Type> {
match expr {
LogicalExpr::Column(col) => {
if let Some((_, col_info)) = schema.find_column(&col.name, col.table.as_deref()) {
Ok(col_info.ty)
} else {
Ok(Type::Text)
}
}
LogicalExpr::Literal(Value::Integer(_)) => Ok(Type::Integer),
LogicalExpr::Literal(Value::Float(_)) => Ok(Type::Real),
LogicalExpr::Literal(Value::Text(_)) => Ok(Type::Text),
LogicalExpr::Literal(Value::Null) => Ok(Type::Null),
LogicalExpr::Literal(Value::Blob(_)) => Ok(Type::Blob),
LogicalExpr::BinaryExpr { op, left, right } => {
match op {
ast::Operator::Add | ast::Operator::Subtract | ast::Operator::Multiply => {
                        // Infer the operand types so the result matches SQLite's numeric behavior
let left_type = Self::infer_expr_type(left, schema)?;
let right_type = Self::infer_expr_type(right, schema)?;
// Integer op Integer = Integer (matching core/numeric/mod.rs behavior)
// Any operation with Real = Real
match (left_type, right_type) {
(Type::Integer, Type::Integer) => Ok(Type::Integer),
(Type::Integer, Type::Real)
| (Type::Real, Type::Integer)
| (Type::Real, Type::Real) => Ok(Type::Real),
(Type::Null, _) | (_, Type::Null) => Ok(Type::Null),
// For Text/Blob, SQLite coerces to numeric, defaulting to Real
_ => Ok(Type::Real),
}
}
ast::Operator::Divide => {
                        // Type division conservatively as Real (SQLite's integer division actually yields an Integer)
Ok(Type::Real)
}
ast::Operator::Modulus => {
// Modulus follows same rules as other arithmetic ops
let left_type = Self::infer_expr_type(left, schema)?;
let right_type = Self::infer_expr_type(right, schema)?;
match (left_type, right_type) {
(Type::Integer, Type::Integer) => Ok(Type::Integer),
_ => Ok(Type::Real),
}
}
ast::Operator::Equals
| ast::Operator::NotEquals
| ast::Operator::Less
| ast::Operator::LessEquals
| ast::Operator::Greater
| ast::Operator::GreaterEquals
| ast::Operator::And
| ast::Operator::Or
| ast::Operator::Is
| ast::Operator::IsNot => Ok(Type::Integer),
ast::Operator::Concat => Ok(Type::Text),
_ => Ok(Type::Text), // Default for other operators
}
}
LogicalExpr::UnaryExpr { op, expr } => match op {
ast::UnaryOperator::Not => Ok(Type::Integer),
ast::UnaryOperator::Negative | ast::UnaryOperator::Positive => {
Self::infer_expr_type(expr, schema)
}
ast::UnaryOperator::BitwiseNot => Ok(Type::Integer),
},
LogicalExpr::AggregateFunction { fun, .. } => match fun {
AggFunc::Count | AggFunc::Count0 => Ok(Type::Integer),
AggFunc::Sum | AggFunc::Avg | AggFunc::Total => Ok(Type::Real),
                // Conservative default; MIN/MAX effectively take the type of their argument
                AggFunc::Min | AggFunc::Max => Ok(Type::Text),
AggFunc::GroupConcat | AggFunc::StringAgg => Ok(Type::Text),
#[cfg(feature = "json")]
AggFunc::JsonbGroupArray
| AggFunc::JsonGroupArray
| AggFunc::JsonbGroupObject
| AggFunc::JsonGroupObject => Ok(Type::Text),
AggFunc::External(_) => Ok(Type::Text), // Default for external
},
LogicalExpr::Alias { expr, .. } => Self::infer_expr_type(expr, schema),
LogicalExpr::IsNull { .. } => Ok(Type::Integer),
LogicalExpr::InList { .. } | LogicalExpr::InSubquery { .. } => Ok(Type::Integer),
LogicalExpr::Exists { .. } => Ok(Type::Integer),
LogicalExpr::Between { .. } => Ok(Type::Integer),
LogicalExpr::Like { .. } => Ok(Type::Integer),
_ => Ok(Type::Text),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::schema::{BTreeTable, Column as SchemaColumn, Schema, Type};
use turso_parser::parser::Parser;
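    // Summary of the test schema built below:
    //   users(id INTEGER PK, name TEXT, age INTEGER, email TEXT)
    //   orders(id INTEGER PK, user_id INTEGER, product TEXT, amount REAL)
    //   products(id INTEGER PK, name TEXT, price REAL, product_id INTEGER)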
fn create_test_schema() -> Schema {
let mut schema = Schema::new(false);
// Create users table
let users_table = BTreeTable {
name: "users".to_string(),
root_page: 2,
primary_key_columns: vec![("id".to_string(), turso_parser::ast::SortOrder::Asc)],
foreign_keys: vec![],
columns: vec![
SchemaColumn {
name: Some("id".to_string()),
ty: Type::Integer,
ty_str: "INTEGER".to_string(),
primary_key: true,
is_rowid_alias: true,
notnull: true,
default: None,
unique: false,
collation: None,
hidden: false,
},
SchemaColumn {
name: Some("name".to_string()),
ty: Type::Text,
ty_str: "TEXT".to_string(),
primary_key: false,
is_rowid_alias: false,
notnull: false,
default: None,
unique: false,
collation: None,
hidden: false,
},
SchemaColumn {
name: Some("age".to_string()),
ty: Type::Integer,
ty_str: "INTEGER".to_string(),
primary_key: false,
is_rowid_alias: false,
notnull: false,
default: None,
unique: false,
collation: None,
hidden: false,
},
SchemaColumn {
name: Some("email".to_string()),
ty: Type::Text,
ty_str: "TEXT".to_string(),
primary_key: false,
is_rowid_alias: false,
notnull: false,
default: None,
unique: false,
collation: None,
hidden: false,
},
],
has_rowid: true,
is_strict: false,
has_autoincrement: false,
unique_sets: vec![],
};
schema
.add_btree_table(Arc::new(users_table))
.expect("Test setup: failed to add users table");
// Create orders table
let orders_table = BTreeTable {
name: "orders".to_string(),
root_page: 3,
primary_key_columns: vec![("id".to_string(), turso_parser::ast::SortOrder::Asc)],
columns: vec![
SchemaColumn {
name: Some("id".to_string()),
ty: Type::Integer,
ty_str: "INTEGER".to_string(),
primary_key: true,
is_rowid_alias: true,
notnull: true,
default: None,
unique: false,
collation: None,
hidden: false,
},
SchemaColumn {
name: Some("user_id".to_string()),
ty: Type::Integer,
ty_str: "INTEGER".to_string(),
primary_key: false,
is_rowid_alias: false,
notnull: false,
default: None,
unique: false,
collation: None,
hidden: false,
},
SchemaColumn {
name: Some("product".to_string()),
ty: Type::Text,
ty_str: "TEXT".to_string(),
primary_key: false,
is_rowid_alias: false,
notnull: false,
default: None,
unique: false,
collation: None,
hidden: false,
},
SchemaColumn {
name: Some("amount".to_string()),
ty: Type::Real,
ty_str: "REAL".to_string(),
primary_key: false,
is_rowid_alias: false,
notnull: false,
default: None,
unique: false,
collation: None,
hidden: false,
},
],
has_rowid: true,
is_strict: false,
has_autoincrement: false,
unique_sets: vec![],
foreign_keys: vec![],
};
schema
.add_btree_table(Arc::new(orders_table))
.expect("Test setup: failed to add orders table");
// Create products table
let products_table = BTreeTable {
name: "products".to_string(),
root_page: 4,
primary_key_columns: vec![("id".to_string(), turso_parser::ast::SortOrder::Asc)],
columns: vec![
SchemaColumn {
name: Some("id".to_string()),
ty: Type::Integer,
ty_str: "INTEGER".to_string(),
primary_key: true,
is_rowid_alias: true,
notnull: true,
default: None,
unique: false,
collation: None,
hidden: false,
},
SchemaColumn {
name: Some("name".to_string()),
ty: Type::Text,
ty_str: "TEXT".to_string(),
primary_key: false,
is_rowid_alias: false,
notnull: false,
default: None,
unique: false,
collation: None,
hidden: false,
},
SchemaColumn {
name: Some("price".to_string()),
ty: Type::Real,
ty_str: "REAL".to_string(),
primary_key: false,
is_rowid_alias: false,
notnull: false,
default: None,
unique: false,
collation: None,
hidden: false,
},
SchemaColumn {
name: Some("product_id".to_string()),
ty: Type::Integer,
ty_str: "INTEGER".to_string(),
primary_key: false,
is_rowid_alias: false,
notnull: false,
default: None,
unique: false,
collation: None,
hidden: false,
},
],
has_rowid: true,
is_strict: false,
has_autoincrement: false,
unique_sets: vec![],
foreign_keys: vec![],
};
schema
.add_btree_table(Arc::new(products_table))
.expect("Test setup: failed to add products table");
schema
}
fn parse_and_build(sql: &str, schema: &Schema) -> Result<LogicalPlan> {
let mut parser = Parser::new(sql.as_bytes());
let cmd = parser
.next()
.ok_or_else(|| LimboError::ParseError("Empty statement".to_string()))?
.map_err(|e| LimboError::ParseError(e.to_string()))?;
match cmd {
ast::Cmd::Stmt(stmt) => {
let mut builder = LogicalPlanBuilder::new(schema);
builder.build_statement(&stmt)
}
_ => Err(LimboError::ParseError(
"Only SQL statements are supported".to_string(),
)),
}
}
#[test]
fn test_simple_select() {
let schema = create_test_schema();
let sql = "SELECT id, name FROM users";
let plan = parse_and_build(sql, &schema).unwrap();
match plan {
LogicalPlan::Projection(proj) => {
assert_eq!(proj.exprs.len(), 2);
assert!(matches!(proj.exprs[0], LogicalExpr::Column(_)));
assert!(matches!(proj.exprs[1], LogicalExpr::Column(_)));
match &*proj.input {
LogicalPlan::TableScan(scan) => {
assert_eq!(scan.table_name, "users");
}
_ => panic!("Expected TableScan"),
}
}
_ => panic!("Expected Projection"),
}
}
#[test]
fn test_select_with_filter() {
let schema = create_test_schema();
let sql = "SELECT name FROM users WHERE age > 18";
let plan = parse_and_build(sql, &schema).unwrap();
match plan {
LogicalPlan::Projection(proj) => {
assert_eq!(proj.exprs.len(), 1);
match &*proj.input {
LogicalPlan::Filter(filter) => {
assert!(matches!(
filter.predicate,
LogicalExpr::BinaryExpr {
op: ast::Operator::Greater,
..
}
));
match &*filter.input {
LogicalPlan::TableScan(scan) => {
assert_eq!(scan.table_name, "users");
}
_ => panic!("Expected TableScan"),
}
}
_ => panic!("Expected Filter"),
}
}
_ => panic!("Expected Projection"),
}
}
#[test]
fn test_aggregate_with_group_by() {
let schema = create_test_schema();
let sql = "SELECT user_id, SUM(amount) FROM orders GROUP BY user_id";
let plan = parse_and_build(sql, &schema).unwrap();
match plan {
LogicalPlan::Aggregate(agg) => {
assert_eq!(agg.group_expr.len(), 1);
assert_eq!(agg.aggr_expr.len(), 1);
assert_eq!(agg.schema.column_count(), 2);
assert!(matches!(
agg.aggr_expr[0],
LogicalExpr::AggregateFunction {
fun: AggFunc::Sum,
..
}
));
match &*agg.input {
LogicalPlan::TableScan(scan) => {
assert_eq!(scan.table_name, "orders");
}
_ => panic!("Expected TableScan"),
}
}
_ => panic!("Expected Aggregate (no projection)"),
}
}
#[test]
fn test_aggregate_without_group_by() {
let schema = create_test_schema();
let sql = "SELECT COUNT(*), MAX(age) FROM users";
let plan = parse_and_build(sql, &schema).unwrap();
match plan {
LogicalPlan::Aggregate(agg) => {
assert_eq!(agg.group_expr.len(), 0);
assert_eq!(agg.aggr_expr.len(), 2);
assert_eq!(agg.schema.column_count(), 2);
assert!(matches!(
agg.aggr_expr[0],
LogicalExpr::AggregateFunction {
fun: AggFunc::Count,
..
}
));
assert!(matches!(
agg.aggr_expr[1],
LogicalExpr::AggregateFunction {
fun: AggFunc::Max,
..
}
));
}
_ => panic!("Expected Aggregate (no projection)"),
}
}
#[test]
fn test_order_by() {
let schema = create_test_schema();
let sql = "SELECT name FROM users ORDER BY age DESC, name ASC";
let plan = parse_and_build(sql, &schema).unwrap();
match plan {
LogicalPlan::Sort(sort) => {
assert_eq!(sort.exprs.len(), 2);
assert!(!sort.exprs[0].asc); // DESC
assert!(sort.exprs[1].asc); // ASC
match &*sort.input {
LogicalPlan::Projection(_) => {}
_ => panic!("Expected Projection"),
}
}
_ => panic!("Expected Sort"),
}
}
#[test]
fn test_limit_offset() {
let schema = create_test_schema();
let sql = "SELECT * FROM users LIMIT 10 OFFSET 5";
let plan = parse_and_build(sql, &schema).unwrap();
match plan {
LogicalPlan::Limit(limit) => {
assert_eq!(limit.fetch, Some(10));
assert_eq!(limit.skip, Some(5));
}
_ => panic!("Expected Limit"),
}
}
#[test]
fn test_order_by_with_limit() {
let schema = create_test_schema();
let sql = "SELECT name FROM users ORDER BY age DESC LIMIT 5";
let plan = parse_and_build(sql, &schema).unwrap();
// Should produce: Limit -> Sort -> Projection -> TableScan
match plan {
LogicalPlan::Limit(limit) => {
assert_eq!(limit.fetch, Some(5));
assert_eq!(limit.skip, None);
match &*limit.input {
LogicalPlan::Sort(sort) => {
assert_eq!(sort.exprs.len(), 1);
assert!(!sort.exprs[0].asc); // DESC
match &*sort.input {
LogicalPlan::Projection(_) => {}
_ => panic!("Expected Projection under Sort"),
}
}
_ => panic!("Expected Sort under Limit"),
}
}
_ => panic!("Expected Limit at top level"),
}
}
#[test]
fn test_distinct() {
let schema = create_test_schema();
let sql = "SELECT DISTINCT name FROM users";
let plan = parse_and_build(sql, &schema).unwrap();
match plan {
LogicalPlan::Distinct(distinct) => match &*distinct.input {
LogicalPlan::Projection(_) => {}
_ => panic!("Expected Projection"),
},
_ => panic!("Expected Distinct"),
}
}
#[test]
fn test_union() {
let schema = create_test_schema();
let sql = "SELECT id FROM users UNION SELECT user_id FROM orders";
let plan = parse_and_build(sql, &schema).unwrap();
match plan {
LogicalPlan::Union(union) => {
assert!(!union.all);
assert_eq!(union.inputs.len(), 2);
}
_ => panic!("Expected Union"),
}
}
#[test]
fn test_union_all() {
let schema = create_test_schema();
let sql = "SELECT id FROM users UNION ALL SELECT user_id FROM orders";
let plan = parse_and_build(sql, &schema).unwrap();
match plan {
LogicalPlan::Union(union) => {
assert!(union.all);
assert_eq!(union.inputs.len(), 2);
}
_ => panic!("Expected Union"),
}
}
#[test]
fn test_union_with_order_by() {
let schema = create_test_schema();
let sql = "SELECT id, name FROM users UNION SELECT user_id, name FROM orders ORDER BY id";
let plan = parse_and_build(sql, &schema).unwrap();
match plan {
LogicalPlan::Sort(sort) => {
assert_eq!(sort.exprs.len(), 1);
assert!(sort.exprs[0].asc); // Default ASC
match &*sort.input {
LogicalPlan::Union(union) => {
assert!(!union.all); // UNION (not UNION ALL)
assert_eq!(union.inputs.len(), 2);
}
_ => panic!("Expected Union under Sort"),
}
}
_ => panic!("Expected Sort at top level"),
}
}
#[test]
fn test_with_cte() {
let schema = create_test_schema();
let sql = "WITH active_users AS (SELECT * FROM users WHERE age > 18) SELECT name FROM active_users";
let plan = parse_and_build(sql, &schema).unwrap();
match plan {
LogicalPlan::WithCTE(with) => {
assert_eq!(with.ctes.len(), 1);
assert!(with.ctes.contains_key("active_users"));
let cte = &with.ctes["active_users"];
match &**cte {
LogicalPlan::Projection(proj) => match &*proj.input {
LogicalPlan::Filter(_) => {}
_ => panic!("Expected Filter in CTE"),
},
_ => panic!("Expected Projection in CTE"),
}
match &*with.body {
LogicalPlan::Projection(proj) => match &*proj.input {
LogicalPlan::CTERef(cte_ref) => {
assert_eq!(cte_ref.name, "active_users");
}
_ => panic!("Expected CTERef"),
},
_ => panic!("Expected Projection in body"),
}
}
_ => panic!("Expected WithCTE"),
}
}
#[test]
fn test_case_expression() {
let schema = create_test_schema();
let sql = "SELECT CASE WHEN age < 18 THEN 'minor' WHEN age < 65 THEN 'adult' ELSE 'senior' END FROM users";
let plan = parse_and_build(sql, &schema).unwrap();
match plan {
LogicalPlan::Projection(proj) => {
assert_eq!(proj.exprs.len(), 1);
assert!(matches!(proj.exprs[0], LogicalExpr::Case { .. }));
}
_ => panic!("Expected Projection"),
}
}
#[test]
fn test_in_list() {
let schema = create_test_schema();
let sql = "SELECT * FROM users WHERE id IN (1, 2, 3)";
let plan = parse_and_build(sql, &schema).unwrap();
match plan {
LogicalPlan::Projection(proj) => match &*proj.input {
LogicalPlan::Filter(filter) => match &filter.predicate {
LogicalExpr::InList { list, negated, .. } => {
assert!(!negated);
assert_eq!(list.len(), 3);
}
_ => panic!("Expected InList"),
},
_ => panic!("Expected Filter"),
},
_ => panic!("Expected Projection"),
}
}
#[test]
fn test_in_subquery() {
let schema = create_test_schema();
let sql = "SELECT * FROM users WHERE id IN (SELECT user_id FROM orders)";
let plan = parse_and_build(sql, &schema).unwrap();
match plan {
LogicalPlan::Projection(proj) => match &*proj.input {
LogicalPlan::Filter(filter) => {
assert!(matches!(filter.predicate, LogicalExpr::InSubquery { .. }));
}
_ => panic!("Expected Filter"),
},
_ => panic!("Expected Projection"),
}
}
#[test]
fn test_exists_subquery() {
let schema = create_test_schema();
let sql = "SELECT * FROM users WHERE EXISTS (SELECT 1 FROM orders WHERE orders.user_id = users.id)";
let plan = parse_and_build(sql, &schema).unwrap();
match plan {
LogicalPlan::Projection(proj) => match &*proj.input {
LogicalPlan::Filter(filter) => {
assert!(matches!(filter.predicate, LogicalExpr::Exists { .. }));
}
_ => panic!("Expected Filter"),
},
_ => panic!("Expected Projection"),
}
}
#[test]
fn test_between() {
let schema = create_test_schema();
let sql = "SELECT * FROM users WHERE age BETWEEN 18 AND 65";
let plan = parse_and_build(sql, &schema).unwrap();
match plan {
LogicalPlan::Projection(proj) => match &*proj.input {
LogicalPlan::Filter(filter) => match &filter.predicate {
LogicalExpr::Between { negated, .. } => {
assert!(!negated);
}
_ => panic!("Expected Between"),
},
_ => panic!("Expected Filter"),
},
_ => panic!("Expected Projection"),
}
}
#[test]
fn test_like() {
let schema = create_test_schema();
let sql = "SELECT * FROM users WHERE name LIKE 'John%'";
let plan = parse_and_build(sql, &schema).unwrap();
match plan {
LogicalPlan::Projection(proj) => match &*proj.input {
LogicalPlan::Filter(filter) => match &filter.predicate {
LogicalExpr::Like {
negated, escape, ..
} => {
assert!(!negated);
assert!(escape.is_none());
}
_ => panic!("Expected Like"),
},
_ => panic!("Expected Filter"),
},
_ => panic!("Expected Projection"),
}
}
#[test]
fn test_is_null() {
let schema = create_test_schema();
let sql = "SELECT * FROM users WHERE email IS NULL";
let plan = parse_and_build(sql, &schema).unwrap();
match plan {
LogicalPlan::Projection(proj) => match &*proj.input {
LogicalPlan::Filter(filter) => match &filter.predicate {
LogicalExpr::IsNull { negated, .. } => {
assert!(!negated);
}
_ => panic!("Expected IsNull, got: {:?}", filter.predicate),
},
_ => panic!("Expected Filter"),
},
_ => panic!("Expected Projection"),
}
}
#[test]
fn test_is_not_null() {
let schema = create_test_schema();
let sql = "SELECT * FROM users WHERE email IS NOT NULL";
let plan = parse_and_build(sql, &schema).unwrap();
match plan {
LogicalPlan::Projection(proj) => match &*proj.input {
LogicalPlan::Filter(filter) => match &filter.predicate {
LogicalExpr::IsNull { negated, .. } => {
assert!(negated);
}
_ => panic!("Expected IsNull"),
},
_ => panic!("Expected Filter"),
},
_ => panic!("Expected Projection"),
}
}
#[test]
fn test_values_clause() {
let schema = create_test_schema();
let sql = "SELECT * FROM (VALUES (1, 'a'), (2, 'b'))";
let plan = parse_and_build(sql, &schema).unwrap();
match plan {
LogicalPlan::Projection(proj) => match &*proj.input {
LogicalPlan::Values(values) => {
assert_eq!(values.rows.len(), 2);
assert_eq!(values.rows[0].len(), 2);
}
_ => panic!("Expected Values"),
},
_ => panic!("Expected Projection"),
}
}
#[test]
fn test_complex_expression_with_aggregation() {
        // Test the complex case: sum((id + 2)) * 2, where the aggregate argument is a
        // parenthesized expression and the aggregate result is used in arithmetic
        let schema = create_test_schema();
let sql = "SELECT sum((id + 2)) * 2 FROM orders GROUP BY user_id";
let plan = parse_and_build(sql, &schema).unwrap();
match plan {
LogicalPlan::Projection(proj) => {
assert_eq!(proj.exprs.len(), 1);
match &proj.exprs[0] {
LogicalExpr::BinaryExpr { left, op, right } => {
assert_eq!(*op, BinaryOperator::Multiply);
assert!(matches!(**left, LogicalExpr::Column(_)));
assert!(matches!(**right, LogicalExpr::Literal(_)));
}
_ => panic!("Expected BinaryExpr in projection"),
}
match &*proj.input {
LogicalPlan::Aggregate(agg) => {
assert_eq!(agg.group_expr.len(), 1);
assert_eq!(agg.aggr_expr.len(), 1);
match &agg.aggr_expr[0] {
LogicalExpr::AggregateFunction { fun, args, .. } => {
assert_eq!(*fun, AggregateFunction::Sum);
assert_eq!(args.len(), 1);
match &args[0] {
LogicalExpr::Column(col) => {
assert!(col.name.starts_with("__agg_arg_proj_"));
}
_ => panic!("Expected Column reference to projected expression in aggregate args, got {:?}", args[0]),
}
}
_ => panic!("Expected AggregateFunction"),
}
match &*agg.input {
LogicalPlan::Projection(inner_proj) => {
assert!(inner_proj.exprs.len() >= 2);
let has_binary_add = inner_proj.exprs.iter().any(|e| {
matches!(
e,
LogicalExpr::BinaryExpr {
op: BinaryOperator::Add,
..
}
)
});
assert!(
has_binary_add,
"Should have id + 2 expression in inner projection"
);
}
_ => panic!("Expected Projection as input to Aggregate"),
}
}
_ => panic!("Expected Aggregate under Projection"),
}
}
_ => panic!("Expected Projection at top level"),
}
}
#[test]
fn test_function_on_aggregate_result() {
let schema = create_test_schema();
let sql = "SELECT abs(sum(id)) FROM orders GROUP BY user_id";
let plan = parse_and_build(sql, &schema).unwrap();
match plan {
LogicalPlan::Projection(proj) => {
assert_eq!(proj.exprs.len(), 1);
match &proj.exprs[0] {
LogicalExpr::ScalarFunction { fun, args } => {
assert_eq!(fun, "abs");
assert_eq!(args.len(), 1);
assert!(matches!(args[0], LogicalExpr::Column(_)));
}
_ => panic!("Expected ScalarFunction in projection"),
}
}
_ => panic!("Expected Projection"),
}
}
#[test]
fn test_multiple_aggregates_with_arithmetic() {
let schema = create_test_schema();
let sql = "SELECT sum(id) * 2 + count(*) FROM orders GROUP BY user_id";
let plan = parse_and_build(sql, &schema).unwrap();
match plan {
LogicalPlan::Projection(proj) => {
assert_eq!(proj.exprs.len(), 1);
match &proj.exprs[0] {
LogicalExpr::BinaryExpr { op, .. } => {
assert_eq!(*op, BinaryOperator::Add);
}
_ => panic!("Expected BinaryExpr"),
}
match &*proj.input {
LogicalPlan::Aggregate(agg) => {
assert_eq!(agg.aggr_expr.len(), 2);
}
_ => panic!("Expected Aggregate"),
}
}
_ => panic!("Expected Projection"),
}
}
#[test]
fn test_projection_aggregation_projection() {
let schema = create_test_schema();
// This tests: projection -> aggregation -> projection
// The inner projection computes (id + 2), then we aggregate sum(), then apply abs()
let sql = "SELECT abs(sum(id + 2)) FROM orders GROUP BY user_id";
let plan = parse_and_build(sql, &schema).unwrap();
// Should produce: Projection(abs) -> Aggregate(sum) -> Projection(id + 2) -> TableScan
match plan {
LogicalPlan::Projection(outer_proj) => {
assert_eq!(outer_proj.exprs.len(), 1);
// Outer projection should apply abs() function
match &outer_proj.exprs[0] {
LogicalExpr::ScalarFunction { fun, args } => {
assert_eq!(fun, "abs");
assert_eq!(args.len(), 1);
assert!(matches!(args[0], LogicalExpr::Column(_)));
}
_ => panic!("Expected abs() function in outer projection"),
}
// Next should be the Aggregate
match &*outer_proj.input {
LogicalPlan::Aggregate(agg) => {
assert_eq!(agg.group_expr.len(), 1);
assert_eq!(agg.aggr_expr.len(), 1);
// The aggregate should be summing a column reference
match &agg.aggr_expr[0] {
LogicalExpr::AggregateFunction { fun, args, .. } => {
assert_eq!(*fun, AggregateFunction::Sum);
assert_eq!(args.len(), 1);
// Should reference the projected column
match &args[0] {
LogicalExpr::Column(col) => {
assert!(col.name.starts_with("__agg_arg_proj_"));
}
_ => panic!("Expected column reference in aggregate"),
}
}
_ => panic!("Expected AggregateFunction"),
}
// Input to aggregate should be a projection computing id + 2
match &*agg.input {
LogicalPlan::Projection(inner_proj) => {
// Should have at least the group column and the computed expression
assert!(inner_proj.exprs.len() >= 2);
// Check for the id + 2 expression
let has_add_expr = inner_proj.exprs.iter().any(|e| {
matches!(
e,
LogicalExpr::BinaryExpr {
op: BinaryOperator::Add,
..
}
)
});
assert!(
has_add_expr,
"Should have id + 2 expression in inner projection"
);
}
_ => panic!("Expected inner Projection under Aggregate"),
}
}
_ => panic!("Expected Aggregate under outer Projection"),
}
}
_ => panic!("Expected Projection at top level"),
}
}
#[test]
fn test_group_by_validation_allow_grouped_column() {
let schema = create_test_schema();
// Test that grouped columns are allowed
let sql = "SELECT user_id, COUNT(*) FROM orders GROUP BY user_id";
let result = parse_and_build(sql, &schema);
assert!(result.is_ok(), "Should allow grouped column in SELECT");
}
#[test]
fn test_group_by_validation_allow_constants() {
let schema = create_test_schema();
// Test that simple constants are allowed even when not grouped
let sql = "SELECT user_id, 42, COUNT(*) FROM orders GROUP BY user_id";
let result = parse_and_build(sql, &schema);
assert!(
result.is_ok(),
"Should allow simple constants in SELECT with GROUP BY"
);
let sql_complex = "SELECT user_id, (100 + 50) * 2, COUNT(*) FROM orders GROUP BY user_id";
let result_complex = parse_and_build(sql_complex, &schema);
assert!(
result_complex.is_ok(),
"Should allow complex constant expressions in SELECT with GROUP BY"
);
}
#[test]
fn test_parenthesized_aggregate_expressions() {
let schema = create_test_schema();
let sql = "SELECT 25, (MAX(id) / 3), 39 FROM orders";
let result = parse_and_build(sql, &schema);
assert!(
result.is_ok(),
"Should handle parenthesized aggregate expressions"
);
let plan = result.unwrap();
match plan {
LogicalPlan::Projection(proj) => {
assert_eq!(proj.exprs.len(), 3);
assert!(matches!(
proj.exprs[0],
LogicalExpr::Literal(Value::Integer(25))
));
match &proj.exprs[1] {
LogicalExpr::BinaryExpr { left, op, right } => {
assert_eq!(*op, BinaryOperator::Divide);
assert!(matches!(&**left, LogicalExpr::Column(_)));
assert!(matches!(&**right, LogicalExpr::Literal(Value::Integer(3))));
}
_ => panic!("Expected BinaryExpr for (MAX(id) / 3)"),
}
assert!(matches!(
proj.exprs[2],
LogicalExpr::Literal(Value::Integer(39))
));
match &*proj.input {
LogicalPlan::Aggregate(agg) => {
assert_eq!(agg.aggr_expr.len(), 1);
assert!(matches!(
agg.aggr_expr[0],
LogicalExpr::AggregateFunction {
fun: AggFunc::Max,
..
}
));
}
_ => panic!("Expected Aggregate node under Projection"),
}
}
_ => panic!("Expected Projection at top level"),
}
}
#[test]
fn test_duplicate_aggregate_reuse() {
let schema = create_test_schema();
let sql = "SELECT (COUNT(*) - 225), 30, COUNT(*) FROM orders";
let result = parse_and_build(sql, &schema);
assert!(result.is_ok(), "Should handle duplicate aggregates");
let plan = result.unwrap();
match plan {
LogicalPlan::Projection(proj) => {
assert_eq!(proj.exprs.len(), 3);
match &proj.exprs[0] {
LogicalExpr::BinaryExpr { left, op, right } => {
assert_eq!(*op, BinaryOperator::Subtract);
match &**left {
LogicalExpr::Column(col) => {
assert!(col.name.starts_with("__agg_") || col.name == "COUNT(*)");
}
_ => panic!("Expected Column reference for COUNT(*)"),
}
assert!(matches!(
&**right,
LogicalExpr::Literal(Value::Integer(225))
));
}
_ => panic!("Expected BinaryExpr for (COUNT(*) - 225)"),
}
assert!(matches!(
proj.exprs[1],
LogicalExpr::Literal(Value::Integer(30))
));
match &proj.exprs[2] {
LogicalExpr::Column(col) => {
assert!(col.name.starts_with("__agg_") || col.name == "COUNT(*)");
}
_ => panic!("Expected Column reference for COUNT(*)"),
}
match &*proj.input {
LogicalPlan::Aggregate(agg) => {
assert_eq!(
agg.aggr_expr.len(),
1,
"Should have only one COUNT(*) aggregate"
);
assert!(matches!(
agg.aggr_expr[0],
LogicalExpr::AggregateFunction {
fun: AggFunc::Count,
..
}
));
}
_ => panic!("Expected Aggregate node under Projection"),
}
}
_ => panic!("Expected Projection at top level"),
}
}
#[test]
fn test_aggregate_without_group_by_allow_constants() {
let schema = create_test_schema();
// Test that constants are allowed with aggregates even without GROUP BY
let sql = "SELECT 42, COUNT(*), MAX(amount) FROM orders";
let result = parse_and_build(sql, &schema);
assert!(
result.is_ok(),
"Should allow simple constants with aggregates without GROUP BY"
);
// Test complex constant expressions
let sql_complex = "SELECT (9 / 6) % 5, COUNT(*), MAX(amount) FROM orders";
let result_complex = parse_and_build(sql_complex, &schema);
assert!(
result_complex.is_ok(),
"Should allow complex constant expressions with aggregates without GROUP BY"
);
}
#[test]
fn test_aggregate_without_group_by_creates_aggregate_node() {
let schema = create_test_schema();
// Test that aggregate without GROUP BY creates proper Aggregate node
let sql = "SELECT MAX(amount) FROM orders";
let plan = parse_and_build(sql, &schema).unwrap();
// Should be: Aggregate -> TableScan (no projection needed for simple aggregate)
match plan {
LogicalPlan::Aggregate(agg) => {
assert_eq!(agg.group_expr.len(), 0, "Should have no group expressions");
assert_eq!(
agg.aggr_expr.len(),
1,
"Should have one aggregate expression"
);
assert_eq!(
agg.schema.column_count(),
1,
"Schema should have one column"
);
}
_ => panic!("Expected Aggregate at top level (no projection)"),
}
}
#[test]
fn test_scalar_vs_aggregate_function_classification() {
let schema = create_test_schema();
// Test MIN/MAX with 1 argument - should be aggregate
let sql = "SELECT MIN(amount) FROM orders";
let plan = parse_and_build(sql, &schema).unwrap();
match plan {
LogicalPlan::Aggregate(agg) => {
assert_eq!(agg.aggr_expr.len(), 1, "MIN(x) should be an aggregate");
match &agg.aggr_expr[0] {
LogicalExpr::AggregateFunction { fun, args, .. } => {
assert!(matches!(fun, AggFunc::Min));
assert_eq!(args.len(), 1);
}
_ => panic!("Expected AggregateFunction"),
}
}
_ => panic!("Expected Aggregate node for MIN(x)"),
}
// Test MIN/MAX with 2 arguments - should be scalar in projection
let sql = "SELECT MIN(amount, user_id) FROM orders";
let plan = parse_and_build(sql, &schema).unwrap();
match plan {
LogicalPlan::Projection(proj) => {
assert_eq!(proj.exprs.len(), 1, "Should have one projection expression");
match &proj.exprs[0] {
LogicalExpr::ScalarFunction { fun, args } => {
assert_eq!(
fun.to_lowercase(),
"min",
"MIN(x,y) should be a scalar function"
);
assert_eq!(args.len(), 2);
}
_ => panic!("Expected ScalarFunction for MIN(x,y)"),
}
}
_ => panic!("Expected Projection node for scalar MIN(x,y)"),
}
// Test MAX with 3 arguments - should be scalar
let sql = "SELECT MAX(amount, user_id, id) FROM orders";
let plan = parse_and_build(sql, &schema).unwrap();
match plan {
LogicalPlan::Projection(proj) => {
assert_eq!(proj.exprs.len(), 1);
match &proj.exprs[0] {
LogicalExpr::ScalarFunction { fun, args } => {
assert_eq!(
fun.to_lowercase(),
"max",
"MAX(x,y,z) should be a scalar function"
);
assert_eq!(args.len(), 3);
}
_ => panic!("Expected ScalarFunction for MAX(x,y,z)"),
}
}
_ => panic!("Expected Projection node for scalar MAX(x,y,z)"),
}
// Test that MIN with 0 args is treated as scalar (will fail later in execution)
let sql = "SELECT MIN() FROM orders";
let plan = parse_and_build(sql, &schema).unwrap();
match plan {
LogicalPlan::Projection(proj) => match &proj.exprs[0] {
LogicalExpr::ScalarFunction { fun, args } => {
assert_eq!(fun.to_lowercase(), "min");
assert_eq!(args.len(), 0, "MIN() should be scalar with 0 args");
}
_ => panic!("Expected ScalarFunction for MIN()"),
},
_ => panic!("Expected Projection for MIN()"),
}
// Test other functions that are always aggregate (COUNT, SUM, AVG)
let sql = "SELECT COUNT(*), SUM(amount), AVG(amount) FROM orders";
let plan = parse_and_build(sql, &schema).unwrap();
match plan {
LogicalPlan::Aggregate(agg) => {
assert_eq!(agg.aggr_expr.len(), 3, "Should have 3 aggregate functions");
for expr in &agg.aggr_expr {
assert!(matches!(expr, LogicalExpr::AggregateFunction { .. }));
}
}
_ => panic!("Expected Aggregate node"),
}
// Test scalar functions that are never aggregates (ABS, ROUND, etc.)
let sql = "SELECT ABS(amount), ROUND(amount), LENGTH(product) FROM orders";
let plan = parse_and_build(sql, &schema).unwrap();
match plan {
LogicalPlan::Projection(proj) => {
assert_eq!(proj.exprs.len(), 3, "Should have 3 scalar functions");
for expr in &proj.exprs {
match expr {
LogicalExpr::ScalarFunction { .. } => {}
_ => panic!("Expected all ScalarFunctions"),
}
}
}
_ => panic!("Expected Projection node for scalar functions"),
}
}
#[test]
fn test_mixed_aggregate_and_group_columns() {
let schema = create_test_schema();
// When selecting both aggregate and grouping columns
let sql = "SELECT user_id, sum(id) FROM orders GROUP BY user_id";
let plan = parse_and_build(sql, &schema).unwrap();
// No projection needed - aggregate outputs exactly what we select
match plan {
LogicalPlan::Aggregate(agg) => {
assert_eq!(agg.group_expr.len(), 1);
assert_eq!(agg.aggr_expr.len(), 1);
assert_eq!(agg.schema.column_count(), 2);
}
_ => panic!("Expected Aggregate (no projection)"),
}
}
#[test]
fn test_scalar_function_wrapping_aggregate_no_group_by() {
// Test: SELECT HEX(SUM(age + 2)) FROM users
// Expected structure:
// Projection { exprs: [ScalarFunction(HEX, [Column])] }
// -> Aggregate { aggr_expr: [Sum(BinaryExpr(age + 2))], group_expr: [] }
// -> Projection { exprs: [BinaryExpr(age + 2)] }
// -> TableScan("users")
let schema = create_test_schema();
let sql = "SELECT HEX(SUM(age + 2)) FROM users";
let mut parser = Parser::new(sql.as_bytes());
let stmt = parser.next().unwrap().unwrap();
let plan = match stmt {
ast::Cmd::Stmt(stmt) => {
let mut builder = LogicalPlanBuilder::new(&schema);
builder.build_statement(&stmt).unwrap()
}
_ => panic!("Expected SQL statement"),
};
match &plan {
LogicalPlan::Projection(proj) => {
assert_eq!(proj.exprs.len(), 1, "Should have one expression");
match &proj.exprs[0] {
LogicalExpr::ScalarFunction { fun, args } => {
assert_eq!(fun, "HEX", "Outer function should be HEX");
assert_eq!(args.len(), 1, "HEX should have one argument");
match &args[0] {
LogicalExpr::Column(_) => {}
LogicalExpr::AggregateFunction { .. } => {
panic!("Aggregate function should not be embedded in projection! It should be in a separate Aggregate operator");
}
_ => panic!(
"Expected column reference as argument to HEX, got: {:?}",
args[0]
),
}
}
_ => panic!("Expected ScalarFunction (HEX), got: {:?}", proj.exprs[0]),
}
match &*proj.input {
LogicalPlan::Aggregate(agg) => {
assert_eq!(agg.group_expr.len(), 0, "Should have no GROUP BY");
assert_eq!(
agg.aggr_expr.len(),
1,
"Should have one aggregate expression"
);
match &agg.aggr_expr[0] {
LogicalExpr::AggregateFunction {
fun,
args,
distinct,
} => {
assert_eq!(*fun, crate::function::AggFunc::Sum, "Should be SUM");
assert!(!distinct, "Should not be DISTINCT");
assert_eq!(args.len(), 1, "SUM should have one argument");
match &args[0] {
LogicalExpr::Column(col) => {
// When aggregate arguments are complex, they get pre-projected
assert!(col.name.starts_with("__agg_arg_proj_"),
"Should reference pre-projected column, got: {}", col.name);
}
LogicalExpr::BinaryExpr { left, op, right } => {
// Simple case without pre-projection (shouldn't happen with current implementation)
assert_eq!(*op, ast::Operator::Add, "Should be addition");
match (&**left, &**right) {
(LogicalExpr::Column(col), LogicalExpr::Literal(val)) => {
assert_eq!(col.name, "age", "Should reference age column");
assert_eq!(*val, Value::Integer(2), "Should add 2");
}
_ => panic!("Expected age + 2"),
}
}
_ => panic!("Expected Column reference or BinaryExpr for aggregate argument, got: {:?}", args[0]),
}
}
_ => panic!("Expected AggregateFunction"),
}
match &*agg.input {
LogicalPlan::TableScan(scan) => {
assert_eq!(scan.table_name, "users");
}
LogicalPlan::Projection(proj) => match &*proj.input {
LogicalPlan::TableScan(scan) => {
assert_eq!(scan.table_name, "users");
}
_ => panic!("Expected TableScan under projection"),
},
_ => panic!("Expected TableScan or Projection under Aggregate"),
}
}
_ => panic!(
"Expected Aggregate operator under Projection, got: {:?}",
proj.input
),
}
}
_ => panic!("Expected Projection as top-level operator, got: {plan:?}"),
}
}
// ===== JOIN TESTS =====
#[test]
fn test_inner_join() {
let schema = create_test_schema();
let sql = "SELECT * FROM users u INNER JOIN orders o ON u.id = o.user_id";
let plan = parse_and_build(sql, &schema).unwrap();
match plan {
LogicalPlan::Projection(proj) => {
match &*proj.input {
LogicalPlan::Join(join) => {
assert_eq!(join.join_type, JoinType::Inner);
assert!(!join.on.is_empty(), "Should have join conditions");
// Check left input is users
match &*join.left {
LogicalPlan::TableScan(scan) => {
assert_eq!(scan.table_name, "users");
}
_ => panic!("Expected TableScan for left input"),
}
// Check right input is orders
match &*join.right {
LogicalPlan::TableScan(scan) => {
assert_eq!(scan.table_name, "orders");
}
_ => panic!("Expected TableScan for right input"),
}
}
_ => panic!("Expected Join under Projection"),
}
}
_ => panic!("Expected Projection at top level"),
}
}
#[test]
fn test_left_join() {
let schema = create_test_schema();
let sql = "SELECT u.name, o.amount FROM users u LEFT JOIN orders o ON u.id = o.user_id";
let plan = parse_and_build(sql, &schema).unwrap();
match plan {
LogicalPlan::Projection(proj) => {
assert_eq!(proj.exprs.len(), 2); // name and amount
match &*proj.input {
LogicalPlan::Join(join) => {
assert_eq!(join.join_type, JoinType::Left);
assert!(!join.on.is_empty(), "Should have join conditions");
}
_ => panic!("Expected Join under Projection"),
}
}
_ => panic!("Expected Projection at top level"),
}
}
#[test]
fn test_right_join() {
let schema = create_test_schema();
let sql = "SELECT * FROM orders o RIGHT JOIN users u ON o.user_id = u.id";
let plan = parse_and_build(sql, &schema).unwrap();
match plan {
LogicalPlan::Projection(proj) => match &*proj.input {
LogicalPlan::Join(join) => {
assert_eq!(join.join_type, JoinType::Right);
assert!(!join.on.is_empty(), "Should have join conditions");
}
_ => panic!("Expected Join under Projection"),
},
_ => panic!("Expected Projection at top level"),
}
}
#[test]
fn test_full_outer_join() {
let schema = create_test_schema();
let sql = "SELECT * FROM users u FULL OUTER JOIN orders o ON u.id = o.user_id";
let plan = parse_and_build(sql, &schema).unwrap();
match plan {
LogicalPlan::Projection(proj) => match &*proj.input {
LogicalPlan::Join(join) => {
assert_eq!(join.join_type, JoinType::Full);
assert!(!join.on.is_empty(), "Should have join conditions");
}
_ => panic!("Expected Join under Projection"),
},
_ => panic!("Expected Projection at top level"),
}
}
#[test]
fn test_cross_join() {
let schema = create_test_schema();
let sql = "SELECT * FROM users CROSS JOIN orders";
let plan = parse_and_build(sql, &schema).unwrap();
match plan {
LogicalPlan::Projection(proj) => match &*proj.input {
LogicalPlan::Join(join) => {
assert_eq!(join.join_type, JoinType::Cross);
assert!(join.on.is_empty(), "Cross join should have no conditions");
assert!(join.filter.is_none(), "Cross join should have no filter");
}
_ => panic!("Expected Join under Projection"),
},
_ => panic!("Expected Projection at top level"),
}
}
#[test]
fn test_join_with_multiple_conditions() {
let schema = create_test_schema();
let sql = "SELECT * FROM users u JOIN orders o ON u.id = o.user_id AND u.age > 18";
let plan = parse_and_build(sql, &schema).unwrap();
match plan {
LogicalPlan::Projection(proj) => {
match &*proj.input {
LogicalPlan::Join(join) => {
assert_eq!(join.join_type, JoinType::Inner);
// Should have at least one equijoin condition
assert!(!join.on.is_empty(), "Should have join conditions");
// Additional conditions may be in filter
// The exact distribution depends on our implementation
}
_ => panic!("Expected Join under Projection"),
}
}
_ => panic!("Expected Projection at top level"),
}
}
#[test]
fn test_join_using_clause() {
let schema = create_test_schema();
// Note: Both tables should have an 'id' column for this to work
let sql = "SELECT * FROM users JOIN orders USING (id)";
let plan = parse_and_build(sql, &schema).unwrap();
match plan {
LogicalPlan::Projection(proj) => match &*proj.input {
LogicalPlan::Join(join) => {
assert_eq!(join.join_type, JoinType::Inner);
assert!(
!join.on.is_empty(),
"USING clause should create join conditions"
);
}
_ => panic!("Expected Join under Projection"),
},
_ => panic!("Expected Projection at top level"),
}
}
#[test]
fn test_natural_join() {
let schema = create_test_schema();
let sql = "SELECT * FROM users NATURAL JOIN orders";
let plan = parse_and_build(sql, &schema).unwrap();
match plan {
LogicalPlan::Projection(proj) => {
match &*proj.input {
LogicalPlan::Join(join) => {
// Natural join finds common columns (id in this case)
// If no common columns, it becomes a cross join
assert!(
!join.on.is_empty() || join.join_type == JoinType::Cross,
"Natural join should either find common columns or become cross join"
);
}
_ => panic!("Expected Join under Projection"),
}
}
_ => panic!("Expected Projection at top level"),
}
}
#[test]
fn test_three_way_join() {
let schema = create_test_schema();
let sql = "SELECT * FROM users u
JOIN orders o ON u.id = o.user_id
JOIN products p ON o.product_id = p.id";
let plan = parse_and_build(sql, &schema).unwrap();
match plan {
LogicalPlan::Projection(proj) => {
match &*proj.input {
LogicalPlan::Join(join2) => {
// Second join (with products)
assert_eq!(join2.join_type, JoinType::Inner);
match &*join2.left {
LogicalPlan::Join(join1) => {
// First join (users with orders)
assert_eq!(join1.join_type, JoinType::Inner);
}
_ => panic!("Expected nested Join for three-way join"),
}
}
_ => panic!("Expected Join under Projection"),
}
}
_ => panic!("Expected Projection at top level"),
}
}
#[test]
fn test_mixed_join_types() {
let schema = create_test_schema();
let sql = "SELECT * FROM users u
LEFT JOIN orders o ON u.id = o.user_id
INNER JOIN products p ON o.product_id = p.id";
let plan = parse_and_build(sql, &schema).unwrap();
match plan {
LogicalPlan::Projection(proj) => {
match &*proj.input {
LogicalPlan::Join(join2) => {
// Second join should be INNER
assert_eq!(join2.join_type, JoinType::Inner);
match &*join2.left {
LogicalPlan::Join(join1) => {
// First join should be LEFT
assert_eq!(join1.join_type, JoinType::Left);
}
_ => panic!("Expected nested Join"),
}
}
_ => panic!("Expected Join under Projection"),
}
}
_ => panic!("Expected Projection at top level"),
}
}
#[test]
fn test_join_with_filter() {
let schema = create_test_schema();
let sql = "SELECT * FROM users u JOIN orders o ON u.id = o.user_id WHERE o.amount > 100";
let plan = parse_and_build(sql, &schema).unwrap();
match plan {
LogicalPlan::Projection(proj) => {
match &*proj.input {
LogicalPlan::Filter(filter) => {
// WHERE clause creates a Filter above the Join
match &*filter.input {
LogicalPlan::Join(join) => {
assert_eq!(join.join_type, JoinType::Inner);
}
_ => panic!("Expected Join under Filter"),
}
}
_ => panic!("Expected Filter under Projection"),
}
}
_ => panic!("Expected Projection at top level"),
}
}
#[test]
fn test_join_with_projection() {
let schema = create_test_schema();
let sql = "SELECT u.name, o.amount FROM users u JOIN orders o ON u.id = o.user_id";
let plan = parse_and_build(sql, &schema).unwrap();
match plan {
LogicalPlan::Projection(proj) => {
assert_eq!(proj.exprs.len(), 2); // u.name and o.amount
match &*proj.input {
LogicalPlan::Join(join) => {
assert_eq!(join.join_type, JoinType::Inner);
}
_ => panic!("Expected Join under Projection"),
}
}
_ => panic!("Expected Projection at top level"),
}
}
#[test]
fn test_join_with_aggregation() {
let schema = create_test_schema();
let sql = "SELECT u.name, SUM(o.amount)
FROM users u JOIN orders o ON u.id = o.user_id
GROUP BY u.name";
let plan = parse_and_build(sql, &schema).unwrap();
match plan {
LogicalPlan::Aggregate(agg) => {
assert_eq!(agg.group_expr.len(), 1); // GROUP BY u.name
assert_eq!(agg.aggr_expr.len(), 1); // SUM(o.amount)
match &*agg.input {
LogicalPlan::Join(join) => {
assert_eq!(join.join_type, JoinType::Inner);
}
_ => panic!("Expected Join under Aggregate"),
}
}
_ => panic!("Expected Aggregate"),
}
}
#[test]
fn test_join_with_order_by() {
let schema = create_test_schema();
let sql = "SELECT * FROM users u JOIN orders o ON u.id = o.user_id ORDER BY o.amount DESC";
let plan = parse_and_build(sql, &schema).unwrap();
match plan {
LogicalPlan::Sort(sort) => {
assert_eq!(sort.exprs.len(), 1);
assert!(!sort.exprs[0].asc); // DESC
match &*sort.input {
LogicalPlan::Projection(proj) => match &*proj.input {
LogicalPlan::Join(join) => {
assert_eq!(join.join_type, JoinType::Inner);
}
_ => panic!("Expected Join under Projection"),
},
_ => panic!("Expected Projection under Sort"),
}
}
_ => panic!("Expected Sort at top level"),
}
}
#[test]
fn test_join_in_subquery() {
let schema = create_test_schema();
let sql = "SELECT * FROM (
SELECT u.id, u.name, o.amount
FROM users u JOIN orders o ON u.id = o.user_id
) WHERE amount > 100";
let plan = parse_and_build(sql, &schema).unwrap();
match plan {
LogicalPlan::Projection(outer_proj) => match &*outer_proj.input {
LogicalPlan::Filter(filter) => match &*filter.input {
LogicalPlan::Projection(inner_proj) => match &*inner_proj.input {
LogicalPlan::Join(join) => {
assert_eq!(join.join_type, JoinType::Inner);
}
_ => panic!("Expected Join in subquery"),
},
_ => panic!("Expected Projection for subquery"),
},
_ => panic!("Expected Filter"),
},
_ => panic!("Expected Projection at top level"),
}
}
#[test]
fn test_join_ambiguous_column() {
let schema = create_test_schema();
// Both users and orders have an 'id' column
let sql = "SELECT id FROM users JOIN orders ON users.id = orders.user_id";
let result = parse_and_build(sql, &schema);
// This might error or succeed depending on how we handle ambiguous columns
// For now, just check that parsing completes
match result {
Ok(_) => {
// If successful, the implementation handles ambiguous columns somehow
}
Err(_) => {
// If error, the implementation rejects ambiguous columns
}
}
}
// Tests for strip_alias function
#[test]
fn test_strip_alias_with_alias() {
let inner_expr = LogicalExpr::Column(Column::new("test"));
let aliased = LogicalExpr::Alias {
expr: Box::new(inner_expr.clone()),
alias: "my_alias".to_string(),
};
let stripped = strip_alias(&aliased);
assert_eq!(stripped, &inner_expr);
}
#[test]
fn test_strip_alias_without_alias() {
let expr = LogicalExpr::Column(Column::new("test"));
let stripped = strip_alias(&expr);
assert_eq!(stripped, &expr);
}
#[test]
fn test_strip_alias_literal() {
let expr = LogicalExpr::Literal(Value::Integer(42));
let stripped = strip_alias(&expr);
assert_eq!(stripped, &expr);
}
#[test]
fn test_strip_alias_scalar_function() {
let expr = LogicalExpr::ScalarFunction {
fun: "substr".to_string(),
args: vec![
LogicalExpr::Column(Column::new("name")),
LogicalExpr::Literal(Value::Integer(1)),
LogicalExpr::Literal(Value::Integer(4)),
],
};
let stripped = strip_alias(&expr);
assert_eq!(stripped, &expr);
}
#[test]
fn test_strip_alias_nested_alias() {
// Test that strip_alias only removes the outermost alias
let inner_expr = LogicalExpr::Column(Column::new("test"));
let inner_alias = LogicalExpr::Alias {
expr: Box::new(inner_expr.clone()),
alias: "inner_alias".to_string(),
};
let outer_alias = LogicalExpr::Alias {
expr: Box::new(inner_alias.clone()),
alias: "outer_alias".to_string(),
};
let stripped = strip_alias(&outer_alias);
assert_eq!(stripped, &inner_alias);
// Stripping again should give us the inner expression
let double_stripped = strip_alias(stripped);
assert_eq!(double_stripped, &inner_expr);
}
#[test]
fn test_strip_alias_comparison_with_alias() {
// Test that two expressions match when one has an alias and one doesn't
let base_expr = LogicalExpr::ScalarFunction {
fun: "substr".to_string(),
args: vec![
LogicalExpr::Column(Column::new("orderdate")),
LogicalExpr::Literal(Value::Integer(1)),
LogicalExpr::Literal(Value::Integer(4)),
],
};
let aliased_expr = LogicalExpr::Alias {
expr: Box::new(base_expr.clone()),
alias: "year".to_string(),
};
// Without strip_alias, they wouldn't match
assert_ne!(&aliased_expr, &base_expr);
// With strip_alias, they should match
assert_eq!(strip_alias(&aliased_expr), &base_expr);
assert_eq!(strip_alias(&base_expr), &base_expr);
}
#[test]
fn test_strip_alias_binary_expr() {
let expr = LogicalExpr::BinaryExpr {
left: Box::new(LogicalExpr::Column(Column::new("a"))),
op: BinaryOperator::Add,
right: Box::new(LogicalExpr::Literal(Value::Integer(1))),
};
let stripped = strip_alias(&expr);
assert_eq!(stripped, &expr);
}
#[test]
fn test_strip_alias_aggregate_function() {
let expr = LogicalExpr::AggregateFunction {
fun: AggFunc::Sum,
args: vec![LogicalExpr::Column(Column::new("amount"))],
distinct: false,
};
let stripped = strip_alias(&expr);
assert_eq!(stripped, &expr);
}
#[test]
fn test_strip_alias_comparison_multiple_expressions() {
// Test comparing a list of expressions with and without aliases
let expr1 = LogicalExpr::Column(Column::new("a"));
let expr2 = LogicalExpr::ScalarFunction {
fun: "substr".to_string(),
args: vec![
LogicalExpr::Column(Column::new("b")),
LogicalExpr::Literal(Value::Integer(1)),
LogicalExpr::Literal(Value::Integer(4)),
],
};
let aliased1 = LogicalExpr::Alias {
expr: Box::new(expr1.clone()),
alias: "col_a".to_string(),
};
let aliased2 = LogicalExpr::Alias {
expr: Box::new(expr2.clone()),
alias: "year".to_string(),
};
let select_exprs = [aliased1, aliased2];
let group_exprs = [expr1.clone(), expr2.clone()];
// Verify that stripping aliases allows matching
for (select_expr, group_expr) in select_exprs.iter().zip(group_exprs.iter()) {
assert_eq!(strip_alias(select_expr), group_expr);
}
}
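    // Illustrative example: qualified column resolution on a hand-built LogicalSchema.
    // The names ("main", "users", "u") are placeholders, independent of the test tables
    // above; the assertions exercise `LogicalSchema::find_column` directly.
    #[test]
    fn test_find_column_qualified_lookup() {
        let schema = LogicalSchema::new(vec![ColumnInfo {
            name: "id".to_string(),
            ty: Type::Integer,
            database: Some("main".to_string()),
            table: Some("users".to_string()),
            table_alias: Some("u".to_string()),
        }]);
        // Unqualified lookup matches by column name alone
        assert!(schema.find_column("id", None).is_some());
        // Qualified lookups match the table alias, the table name, or "db.table"
        assert!(schema.find_column("id", Some("u")).is_some());
        assert!(schema.find_column("id", Some("users")).is_some());
        assert!(schema.find_column("id", Some("main.users")).is_some());
        // A qualifier that matches neither alias nor table does not resolve
        assert!(schema.find_column("id", Some("orders")).is_none());
    }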
}