//! Logical plan representation for SQL queries
//!
//! This module provides a platform-independent intermediate representation
//! for SQL queries. The logical plan is a DAG (Directed Acyclic Graph) that
//! supports CTEs and can be used for query optimization before being compiled
//! to an execution plan (e.g., DBSP circuits).
//!
//! The main entry point is `LogicalPlanBuilder` which constructs logical plans
//! from SQL AST nodes.
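//!
//! A minimal usage sketch (illustrative only; obtaining the `ast::Stmt` from the
//! SQL parser is assumed and not shown here):
//!
//! ```ignore
//! let mut builder = LogicalPlanBuilder::new(&schema);
//! let plan = builder.build_statement(&stmt)?;
//! // Inspect the output schema of the resulting plan.
//! println!("{:?}", plan.schema());
//! ```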

use crate::function::AggFunc;
use crate::schema::{Schema, Type};
use crate::types::Value;
use crate::{LimboError, Result};
use std::collections::HashMap;
use std::fmt::{self, Display, Formatter};
use std::sync::Arc;
use turso_macros::match_ignore_ascii_case;
use turso_parser::ast;

/// Result type for preprocessing aggregate expressions
type PreprocessAggregateResult = (
    bool,             // needs_pre_projection
    Vec<LogicalExpr>, // pre_projection_exprs
    Vec<ColumnInfo>,  // pre_projection_schema
    Vec<LogicalExpr>, // modified_aggr_exprs
);

/// Result type for parsing join conditions
type JoinConditionsResult = (Vec<(LogicalExpr, LogicalExpr)>, Option<LogicalExpr>);

/// Information about a column in a logical schema
#[derive(Debug, Clone, PartialEq)]
pub struct ColumnInfo {
    pub name: String,
    pub ty: Type,
    pub database: Option<String>,
    pub table: Option<String>,
    pub table_alias: Option<String>,
}

/// Schema information for logical plan nodes
#[derive(Debug, Clone, PartialEq)]
pub struct LogicalSchema {
    pub columns: Vec<ColumnInfo>,
}

/// A reference to a schema that can be shared between nodes
pub type SchemaRef = Arc<LogicalSchema>;

impl LogicalSchema {
|
|
pub fn new(columns: Vec<ColumnInfo>) -> Self {
|
|
Self { columns }
|
|
}
|
|
|
|
pub fn empty() -> Self {
|
|
Self {
|
|
columns: Vec::new(),
|
|
}
|
|
}
|
|
|
|
pub fn column_count(&self) -> usize {
|
|
self.columns.len()
|
|
}
|
|
|
|
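    /// Look up a column by name, optionally qualified by a table name, table
    /// alias, or a `database.table` prefix.
    ///
    /// Behavior sketch (illustrative): `find_column("id", None)` matches the
    /// first column named `id`; `find_column("id", Some("u"))` matches when `u`
    /// is the column's table name or alias; `find_column("id", Some("main.users"))`
    /// additionally requires the database qualifier to match.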
pub fn find_column(&self, name: &str, table: Option<&str>) -> Option<(usize, &ColumnInfo)> {
|
|
if let Some(table_ref) = table {
|
|
// Check if it's a database.table format
|
|
if table_ref.contains('.') {
|
|
let parts: Vec<&str> = table_ref.splitn(2, '.').collect();
|
|
if parts.len() == 2 {
|
|
let db = parts[0];
|
|
let tbl = parts[1];
|
|
return self
|
|
.columns
|
|
.iter()
|
|
.position(|c| {
|
|
c.name == name
|
|
&& c.database.as_deref() == Some(db)
|
|
&& c.table.as_deref() == Some(tbl)
|
|
})
|
|
.map(|idx| (idx, &self.columns[idx]));
|
|
}
|
|
}
|
|
|
|
// Try to match against table alias first, then table name
|
|
self.columns
|
|
.iter()
|
|
.position(|c| {
|
|
c.name == name
|
|
&& (c.table_alias.as_deref() == Some(table_ref)
|
|
|| c.table.as_deref() == Some(table_ref))
|
|
})
|
|
.map(|idx| (idx, &self.columns[idx]))
|
|
} else {
|
|
// Unqualified lookup - just match by name
|
|
self.columns
|
|
.iter()
|
|
.position(|c| c.name == name)
|
|
.map(|idx| (idx, &self.columns[idx]))
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Logical representation of a SQL query plan
|
|
#[derive(Debug, Clone, PartialEq)]
|
|
pub enum LogicalPlan {
|
|
/// Projection - SELECT expressions
|
|
Projection(Projection),
|
|
/// Filter - WHERE/HAVING clause
|
|
Filter(Filter),
|
|
/// Aggregate - GROUP BY with aggregate functions
|
|
Aggregate(Aggregate),
|
|
/// Join - combining two relations
|
|
Join(Join),
|
|
/// Sort - ORDER BY clause
|
|
Sort(Sort),
|
|
/// Limit - LIMIT/OFFSET clause
|
|
Limit(Limit),
|
|
/// Table scan - reading from a base table
|
|
TableScan(TableScan),
|
|
/// Union - UNION/UNION ALL/INTERSECT/EXCEPT
|
|
Union(Union),
|
|
/// Distinct - remove duplicates
|
|
Distinct(Distinct),
|
|
/// Empty relation - no rows
|
|
EmptyRelation(EmptyRelation),
|
|
/// Values - literal rows (VALUES clause)
|
|
Values(Values),
|
|
/// CTE support - WITH clause
|
|
WithCTE(WithCTE),
|
|
/// Reference to a CTE
|
|
CTERef(CTERef),
|
|
}
|
|
|
|
impl LogicalPlan {
|
|
/// Get the schema of this plan node
|
|
pub fn schema(&self) -> &SchemaRef {
|
|
match self {
|
|
LogicalPlan::Projection(p) => &p.schema,
|
|
LogicalPlan::Filter(f) => f.input.schema(),
|
|
LogicalPlan::Aggregate(a) => &a.schema,
|
|
LogicalPlan::Join(j) => &j.schema,
|
|
LogicalPlan::Sort(s) => s.input.schema(),
|
|
LogicalPlan::Limit(l) => l.input.schema(),
|
|
LogicalPlan::TableScan(t) => &t.schema,
|
|
LogicalPlan::Union(u) => &u.schema,
|
|
LogicalPlan::Distinct(d) => d.input.schema(),
|
|
LogicalPlan::EmptyRelation(e) => &e.schema,
|
|
LogicalPlan::Values(v) => &v.schema,
|
|
LogicalPlan::WithCTE(w) => w.body.schema(),
|
|
LogicalPlan::CTERef(c) => &c.schema,
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Projection operator - SELECT expressions
|
|
#[derive(Debug, Clone, PartialEq)]
|
|
pub struct Projection {
|
|
pub input: Arc<LogicalPlan>,
|
|
pub exprs: Vec<LogicalExpr>,
|
|
pub schema: SchemaRef,
|
|
}
|
|
|
|
/// Filter operator - WHERE/HAVING predicates
|
|
#[derive(Debug, Clone, PartialEq)]
|
|
pub struct Filter {
|
|
pub input: Arc<LogicalPlan>,
|
|
pub predicate: LogicalExpr,
|
|
}
|
|
|
|
/// Aggregate operator - GROUP BY with aggregations
|
|
#[derive(Debug, Clone, PartialEq)]
|
|
pub struct Aggregate {
|
|
pub input: Arc<LogicalPlan>,
|
|
pub group_expr: Vec<LogicalExpr>,
|
|
pub aggr_expr: Vec<LogicalExpr>,
|
|
pub schema: SchemaRef,
|
|
}
|
|
|
|
/// Types of joins
|
|
#[derive(Debug, Clone, Copy, PartialEq)]
|
|
pub enum JoinType {
|
|
Inner,
|
|
Left,
|
|
Right,
|
|
Full,
|
|
Cross,
|
|
}
|
|
|
|
/// Join operator - combines two relations
|
|
#[derive(Debug, Clone, PartialEq)]
|
|
pub struct Join {
|
|
pub left: Arc<LogicalPlan>,
|
|
pub right: Arc<LogicalPlan>,
|
|
pub join_type: JoinType,
|
|
pub on: Vec<(LogicalExpr, LogicalExpr)>, // Equijoin conditions (left_expr, right_expr)
|
|
pub filter: Option<LogicalExpr>, // Additional filter conditions
|
|
pub schema: SchemaRef,
|
|
}
|
|
|
|
/// Sort operator - ORDER BY
|
|
#[derive(Debug, Clone, PartialEq)]
|
|
pub struct Sort {
|
|
pub input: Arc<LogicalPlan>,
|
|
pub exprs: Vec<SortExpr>,
|
|
}
|
|
|
|
/// Sort expression with direction
|
|
#[derive(Debug, Clone, PartialEq)]
|
|
pub struct SortExpr {
|
|
pub expr: LogicalExpr,
|
|
pub asc: bool,
|
|
pub nulls_first: bool,
|
|
}
|
|
|
|
/// Limit operator - LIMIT/OFFSET
|
|
#[derive(Debug, Clone, PartialEq)]
|
|
pub struct Limit {
|
|
pub input: Arc<LogicalPlan>,
|
|
pub skip: Option<usize>,
|
|
pub fetch: Option<usize>,
|
|
}
|
|
|
|
/// Table scan operator
|
|
#[derive(Debug, Clone, PartialEq)]
|
|
pub struct TableScan {
|
|
pub table_name: String,
|
|
pub alias: Option<String>,
|
|
pub schema: SchemaRef,
|
|
pub projection: Option<Vec<usize>>, // Column indices to project
|
|
}
|
|
|
|
/// Union operator
|
|
#[derive(Debug, Clone, PartialEq)]
|
|
pub struct Union {
|
|
pub inputs: Vec<Arc<LogicalPlan>>,
|
|
pub all: bool, // true for UNION ALL, false for UNION
|
|
pub schema: SchemaRef,
|
|
}
|
|
|
|
/// Distinct operator
|
|
#[derive(Debug, Clone, PartialEq)]
|
|
pub struct Distinct {
|
|
pub input: Arc<LogicalPlan>,
|
|
}
|
|
|
|
/// Empty relation - produces no rows
|
|
#[derive(Debug, Clone, PartialEq)]
|
|
pub struct EmptyRelation {
|
|
pub produce_one_row: bool,
|
|
pub schema: SchemaRef,
|
|
}
|
|
|
|
/// Values operator - literal rows
|
|
#[derive(Debug, Clone, PartialEq)]
|
|
pub struct Values {
|
|
pub rows: Vec<Vec<LogicalExpr>>,
|
|
pub schema: SchemaRef,
|
|
}
|
|
|
|
/// WITH clause - CTEs
|
|
#[derive(Debug, Clone, PartialEq)]
|
|
pub struct WithCTE {
|
|
pub ctes: HashMap<String, Arc<LogicalPlan>>,
|
|
pub body: Arc<LogicalPlan>,
|
|
}
|
|
|
|
/// Reference to a CTE
|
|
#[derive(Debug, Clone, PartialEq)]
|
|
pub struct CTERef {
|
|
pub name: String,
|
|
pub schema: SchemaRef,
|
|
}
|
|
|
|
/// Logical expression representation
|
|
#[derive(Debug, Clone, PartialEq)]
|
|
pub enum LogicalExpr {
|
|
/// Column reference
|
|
Column(Column),
|
|
/// Literal value
|
|
Literal(Value),
|
|
/// Binary expression
|
|
BinaryExpr {
|
|
left: Box<LogicalExpr>,
|
|
op: BinaryOperator,
|
|
right: Box<LogicalExpr>,
|
|
},
|
|
/// Unary expression
|
|
UnaryExpr {
|
|
op: UnaryOperator,
|
|
expr: Box<LogicalExpr>,
|
|
},
|
|
/// Aggregate function
|
|
AggregateFunction {
|
|
fun: AggregateFunction,
|
|
args: Vec<LogicalExpr>,
|
|
distinct: bool,
|
|
},
|
|
/// Scalar function call
|
|
ScalarFunction { fun: String, args: Vec<LogicalExpr> },
|
|
/// CASE expression
|
|
Case {
|
|
expr: Option<Box<LogicalExpr>>,
|
|
when_then: Vec<(LogicalExpr, LogicalExpr)>,
|
|
else_expr: Option<Box<LogicalExpr>>,
|
|
},
|
|
/// IN list
|
|
InList {
|
|
expr: Box<LogicalExpr>,
|
|
list: Vec<LogicalExpr>,
|
|
negated: bool,
|
|
},
|
|
/// IN subquery
|
|
InSubquery {
|
|
expr: Box<LogicalExpr>,
|
|
subquery: Arc<LogicalPlan>,
|
|
negated: bool,
|
|
},
|
|
/// EXISTS subquery
|
|
Exists {
|
|
subquery: Arc<LogicalPlan>,
|
|
negated: bool,
|
|
},
|
|
/// Scalar subquery
|
|
ScalarSubquery(Arc<LogicalPlan>),
|
|
/// Alias for an expression
|
|
Alias {
|
|
expr: Box<LogicalExpr>,
|
|
alias: String,
|
|
},
|
|
/// IS NULL / IS NOT NULL
|
|
IsNull {
|
|
expr: Box<LogicalExpr>,
|
|
negated: bool,
|
|
},
|
|
/// BETWEEN
|
|
Between {
|
|
expr: Box<LogicalExpr>,
|
|
low: Box<LogicalExpr>,
|
|
high: Box<LogicalExpr>,
|
|
negated: bool,
|
|
},
|
|
/// LIKE pattern matching
|
|
Like {
|
|
expr: Box<LogicalExpr>,
|
|
pattern: Box<LogicalExpr>,
|
|
escape: Option<char>,
|
|
negated: bool,
|
|
},
|
|
/// CAST expression
|
|
Cast {
|
|
expr: Box<LogicalExpr>,
|
|
type_name: Option<ast::Type>,
|
|
},
|
|
}
|
|
|
|
/// Column reference
|
|
#[derive(Debug, Clone, PartialEq)]
|
|
pub struct Column {
|
|
pub name: String,
|
|
pub table: Option<String>,
|
|
}
|
|
|
|
impl Column {
|
|
pub fn new(name: impl Into<String>) -> Self {
|
|
Self {
|
|
name: name.into(),
|
|
table: None,
|
|
}
|
|
}
|
|
|
|
pub fn with_table(name: impl Into<String>, table: impl Into<String>) -> Self {
|
|
Self {
|
|
name: name.into(),
|
|
table: Some(table.into()),
|
|
}
|
|
}
|
|
}
|
|
|
|
impl Display for Column {
|
|
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
|
|
match &self.table {
|
|
Some(t) => write!(f, "{}.{}", t, self.name),
|
|
None => write!(f, "{}", self.name),
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Type alias for binary operators
|
|
pub type BinaryOperator = ast::Operator;
|
|
|
|
/// Type alias for unary operators
|
|
pub type UnaryOperator = ast::UnaryOperator;
|
|
|
|
/// Type alias for aggregate functions
|
|
pub type AggregateFunction = AggFunc;
|
|
|
|
/// Compiler from AST to LogicalPlan
|
|
pub struct LogicalPlanBuilder<'a> {
|
|
schema: &'a Schema,
|
|
ctes: HashMap<String, Arc<LogicalPlan>>,
|
|
}
|
|
|
|
impl<'a> LogicalPlanBuilder<'a> {
|
|
pub fn new(schema: &'a Schema) -> Self {
|
|
Self {
|
|
schema,
|
|
ctes: HashMap::new(),
|
|
}
|
|
}
|
|
|
|
/// Main entry point: compile a statement to a logical plan
|
|
pub fn build_statement(&mut self, stmt: &ast::Stmt) -> Result<LogicalPlan> {
|
|
match stmt {
|
|
ast::Stmt::Select(select) => self.build_select(select),
|
|
_ => Err(LimboError::ParseError(
|
|
"Only SELECT statements are currently supported in logical plans".to_string(),
|
|
)),
|
|
}
|
|
}
|
|
|
|
// Convert Name to String
|
|
fn name_to_string(name: &ast::Name) -> String {
|
|
match name {
|
|
ast::Name::Ident(s) | ast::Name::Quoted(s) => s.clone(),
|
|
}
|
|
}
|
|
|
|
    // Build a logical plan from a SELECT statement
|
|
fn build_select(&mut self, select: &ast::Select) -> Result<LogicalPlan> {
|
|
// Handle WITH clause if present
|
|
if let Some(with) = &select.with {
|
|
return self.build_with_cte(with, select);
|
|
}
|
|
|
|
// Build the main query body
|
|
let order_by = &select.order_by;
|
|
let limit = &select.limit;
|
|
self.build_select_body(&select.body, order_by, limit)
|
|
}
|
|
|
|
// Build WITH CTE
|
|
fn build_with_cte(&mut self, with: &ast::With, select: &ast::Select) -> Result<LogicalPlan> {
|
|
let mut cte_plans = HashMap::new();
|
|
|
|
// Build each CTE
|
|
for cte in &with.ctes {
|
|
let cte_plan = self.build_select(&cte.select)?;
|
|
let cte_name = Self::name_to_string(&cte.tbl_name);
|
|
cte_plans.insert(cte_name.clone(), Arc::new(cte_plan));
|
|
self.ctes
|
|
.insert(cte_name.clone(), cte_plans[&cte_name].clone());
|
|
}
|
|
|
|
// Build the main body with CTEs available
|
|
let order_by = &select.order_by;
|
|
let limit = &select.limit;
|
|
let body = self.build_select_body(&select.body, order_by, limit)?;
|
|
|
|
// Clear CTEs from builder context
|
|
for cte in &with.ctes {
|
|
self.ctes.remove(&Self::name_to_string(&cte.tbl_name));
|
|
}
|
|
|
|
Ok(LogicalPlan::WithCTE(WithCTE {
|
|
ctes: cte_plans,
|
|
body: Arc::new(body),
|
|
}))
|
|
}
|
|
|
|
// Build SELECT body
|
|
fn build_select_body(
|
|
&mut self,
|
|
body: &ast::SelectBody,
|
|
order_by: &[ast::SortedColumn],
|
|
limit: &Option<ast::Limit>,
|
|
) -> Result<LogicalPlan> {
|
|
let mut plan = self.build_one_select(&body.select)?;
|
|
|
|
// Handle compound operators (UNION, INTERSECT, EXCEPT)
|
|
if !body.compounds.is_empty() {
|
|
for compound in &body.compounds {
|
|
let right = self.build_one_select(&compound.select)?;
|
|
plan = Self::build_compound(plan, right, &compound.operator)?;
|
|
}
|
|
}
|
|
|
|
// Apply ORDER BY
|
|
if !order_by.is_empty() {
|
|
plan = self.build_sort(plan, order_by)?;
|
|
}
|
|
|
|
// Apply LIMIT
|
|
if let Some(limit) = limit {
|
|
plan = Self::build_limit(plan, limit)?;
|
|
}
|
|
|
|
Ok(plan)
|
|
}
|
|
|
|
// Build a single SELECT (without compounds)
|
|
fn build_one_select(&mut self, select: &ast::OneSelect) -> Result<LogicalPlan> {
|
|
match select {
|
|
ast::OneSelect::Select {
|
|
distinctness,
|
|
columns,
|
|
from,
|
|
where_clause,
|
|
group_by,
|
|
window_clause: _,
|
|
} => {
|
|
// Start with FROM clause
|
|
let mut plan = if let Some(from) = from {
|
|
self.build_from(from)?
|
|
} else {
|
|
// No FROM clause - single row
|
|
LogicalPlan::EmptyRelation(EmptyRelation {
|
|
produce_one_row: true,
|
|
schema: Arc::new(LogicalSchema::empty()),
|
|
})
|
|
};
|
|
|
|
// Apply WHERE
|
|
if let Some(where_expr) = where_clause {
|
|
let predicate = self.build_expr(where_expr, plan.schema())?;
|
|
plan = LogicalPlan::Filter(Filter {
|
|
input: Arc::new(plan),
|
|
predicate,
|
|
});
|
|
}
|
|
|
|
// Apply GROUP BY and aggregations
|
|
if let Some(group_by) = group_by {
|
|
plan = self.build_aggregate(plan, group_by, columns)?;
|
|
} else if Self::has_aggregates(columns) {
|
|
// Aggregation without GROUP BY
|
|
plan = self.build_aggregate_no_group(plan, columns)?;
|
|
} else {
|
|
// Regular projection
|
|
plan = self.build_projection(plan, columns)?;
|
|
}
|
|
|
|
// Apply HAVING (part of GROUP BY)
|
|
if let Some(ref group_by) = group_by {
|
|
if let Some(ref having_expr) = group_by.having {
|
|
let predicate = self.build_expr(having_expr, plan.schema())?;
|
|
plan = LogicalPlan::Filter(Filter {
|
|
input: Arc::new(plan),
|
|
predicate,
|
|
});
|
|
}
|
|
}
|
|
|
|
// Apply DISTINCT
|
|
if distinctness.is_some() {
|
|
plan = LogicalPlan::Distinct(Distinct {
|
|
input: Arc::new(plan),
|
|
});
|
|
}
|
|
|
|
Ok(plan)
|
|
}
|
|
ast::OneSelect::Values(values) => self.build_values(values),
|
|
}
|
|
}
|
|
|
|
// Build FROM clause
|
|
fn build_from(&mut self, from: &ast::FromClause) -> Result<LogicalPlan> {
|
|
let mut plan = { self.build_select_table(&from.select)? };
|
|
|
|
// Handle JOINs
|
|
if !from.joins.is_empty() {
|
|
for join in &from.joins {
|
|
let right = self.build_select_table(&join.table)?;
|
|
plan = self.build_join(plan, right, &join.operator, &join.constraint)?;
|
|
}
|
|
}
|
|
|
|
Ok(plan)
|
|
}
|
|
|
|
// Build a table reference
|
|
fn build_select_table(&mut self, table: &ast::SelectTable) -> Result<LogicalPlan> {
|
|
match table {
|
|
ast::SelectTable::Table(name, alias, _indexed) => {
|
|
let table_name = Self::name_to_string(&name.name);
|
|
// Check if it's a CTE reference
|
|
if let Some(cte_plan) = self.ctes.get(&table_name) {
|
|
return Ok(LogicalPlan::CTERef(CTERef {
|
|
name: table_name.clone(),
|
|
schema: cte_plan.schema().clone(),
|
|
}));
|
|
}
|
|
|
|
// Regular table scan
|
|
let table_alias = alias.as_ref().map(|a| match a {
|
|
ast::As::As(name) => Self::name_to_string(name),
|
|
ast::As::Elided(name) => Self::name_to_string(name),
|
|
});
|
|
let table_schema = self.get_table_schema(&table_name, table_alias.as_deref())?;
|
|
Ok(LogicalPlan::TableScan(TableScan {
|
|
table_name,
|
|
alias: table_alias.clone(),
|
|
schema: table_schema,
|
|
projection: None,
|
|
}))
|
|
}
|
|
ast::SelectTable::Select(subquery, _alias) => self.build_select(subquery),
|
|
ast::SelectTable::TableCall(_, _, _) => Err(LimboError::ParseError(
|
|
"Table-valued functions are not supported in logical plans".to_string(),
|
|
)),
|
|
ast::SelectTable::Sub(_, _) => Err(LimboError::ParseError(
|
|
"Subquery in FROM clause not yet supported".to_string(),
|
|
)),
|
|
}
|
|
}
|
|
|
|
// Build JOIN
|
|
fn build_join(
|
|
&mut self,
|
|
left: LogicalPlan,
|
|
right: LogicalPlan,
|
|
op: &ast::JoinOperator,
|
|
constraint: &Option<ast::JoinConstraint>,
|
|
) -> Result<LogicalPlan> {
|
|
// Determine join type
|
|
let join_type = match op {
|
|
ast::JoinOperator::Comma => JoinType::Cross, // Comma is essentially a cross join
|
|
ast::JoinOperator::TypedJoin(Some(jt)) => {
|
|
// Check the join type flags
|
|
// Note: JoinType can have multiple flags set
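                // For example (illustrative): a `LEFT OUTER JOIN` typically sets
                // both LEFT and OUTER, while `FULL OUTER JOIN` sets LEFT, RIGHT
                // and OUTER; the checks below decode those flag combinations.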
|
|
if jt.contains(ast::JoinType::NATURAL) {
|
|
// Natural joins need special handling - find common columns
|
|
return self.build_natural_join(left, right, JoinType::Inner);
|
|
} else if jt.contains(ast::JoinType::LEFT)
|
|
&& jt.contains(ast::JoinType::RIGHT)
|
|
&& jt.contains(ast::JoinType::OUTER)
|
|
{
|
|
// FULL OUTER JOIN (has LEFT, RIGHT, and OUTER)
|
|
JoinType::Full
|
|
} else if jt.contains(ast::JoinType::LEFT) && jt.contains(ast::JoinType::OUTER) {
|
|
JoinType::Left
|
|
} else if jt.contains(ast::JoinType::RIGHT) && jt.contains(ast::JoinType::OUTER) {
|
|
JoinType::Right
|
|
} else if jt.contains(ast::JoinType::OUTER)
|
|
&& !jt.contains(ast::JoinType::LEFT)
|
|
&& !jt.contains(ast::JoinType::RIGHT)
|
|
{
|
|
// Plain OUTER JOIN should also be FULL
|
|
JoinType::Full
|
|
} else if jt.contains(ast::JoinType::LEFT) {
|
|
JoinType::Left
|
|
} else if jt.contains(ast::JoinType::RIGHT) {
|
|
JoinType::Right
|
|
} else if jt.contains(ast::JoinType::CROSS)
|
|
|| (jt.contains(ast::JoinType::INNER) && jt.contains(ast::JoinType::CROSS))
|
|
{
|
|
JoinType::Cross
|
|
} else {
|
|
JoinType::Inner // Default to inner
|
|
}
|
|
}
|
|
ast::JoinOperator::TypedJoin(None) => JoinType::Inner, // Default JOIN is INNER JOIN
|
|
};
|
|
|
|
// Build join conditions
|
|
let (on_conditions, filter) = match constraint {
|
|
Some(ast::JoinConstraint::On(expr)) => {
|
|
// Parse ON clause into equijoin conditions and filters
|
|
self.parse_join_conditions(expr, left.schema(), right.schema())?
|
|
}
|
|
Some(ast::JoinConstraint::Using(columns)) => {
|
|
// Build equijoin conditions from USING clause
|
|
let on = self.build_using_conditions(columns, left.schema(), right.schema())?;
|
|
(on, None)
|
|
}
|
|
None => {
|
|
// Cross join or natural join
|
|
(Vec::new(), None)
|
|
}
|
|
};
|
|
|
|
// Build combined schema
|
|
let schema = self.build_join_schema(&left, &right, &join_type)?;
|
|
|
|
Ok(LogicalPlan::Join(Join {
|
|
left: Arc::new(left),
|
|
right: Arc::new(right),
|
|
join_type,
|
|
on: on_conditions,
|
|
filter,
|
|
schema,
|
|
}))
|
|
}
|
|
|
|
// Helper: Parse join conditions into equijoins and filters
|
|
fn parse_join_conditions(
|
|
&mut self,
|
|
expr: &ast::Expr,
|
|
left_schema: &SchemaRef,
|
|
right_schema: &SchemaRef,
|
|
) -> Result<JoinConditionsResult> {
|
|
// For now, we'll handle simple equality conditions
|
|
// More complex conditions will go into the filter
|
|
let mut equijoins = Vec::new();
|
|
let mut filters = Vec::new();
|
|
|
|
// Try to extract equijoin conditions from the expression
|
|
self.extract_equijoin_conditions(
|
|
expr,
|
|
left_schema,
|
|
right_schema,
|
|
&mut equijoins,
|
|
&mut filters,
|
|
)?;
|
|
|
|
let filter = if filters.is_empty() {
|
|
None
|
|
} else {
|
|
// Combine multiple filters with AND
|
|
Some(
|
|
filters
|
|
.into_iter()
|
|
.reduce(|acc, e| LogicalExpr::BinaryExpr {
|
|
left: Box::new(acc),
|
|
op: BinaryOperator::And,
|
|
right: Box::new(e),
|
|
})
|
|
.unwrap(),
|
|
)
|
|
};
|
|
|
|
Ok((equijoins, filter))
|
|
}
|
|
|
|
// Helper: Extract equijoin conditions from expression
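    //
    // For example (illustrative), `ON t1.a = t2.b AND t1.c > 5` yields one
    // equijoin pair `(t1.a, t2.b)` and pushes `t1.c > 5` into the filter list.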
|
|
fn extract_equijoin_conditions(
|
|
&mut self,
|
|
expr: &ast::Expr,
|
|
left_schema: &SchemaRef,
|
|
right_schema: &SchemaRef,
|
|
equijoins: &mut Vec<(LogicalExpr, LogicalExpr)>,
|
|
filters: &mut Vec<LogicalExpr>,
|
|
) -> Result<()> {
|
|
match expr {
|
|
ast::Expr::Binary(lhs, ast::Operator::Equals, rhs) => {
|
|
// Check if this is an equijoin condition (left.col = right.col)
|
|
let left_expr = self.build_expr(lhs, left_schema)?;
|
|
let right_expr = self.build_expr(rhs, right_schema)?;
|
|
|
|
// For simplicity, we'll check if one references left and one references right
|
|
// In a real implementation, we'd need more sophisticated column resolution
|
|
equijoins.push((left_expr, right_expr));
|
|
}
|
|
ast::Expr::Binary(lhs, ast::Operator::And, rhs) => {
|
|
// Recursively extract from AND conditions
|
|
self.extract_equijoin_conditions(
|
|
lhs,
|
|
left_schema,
|
|
right_schema,
|
|
equijoins,
|
|
filters,
|
|
)?;
|
|
self.extract_equijoin_conditions(
|
|
rhs,
|
|
left_schema,
|
|
right_schema,
|
|
equijoins,
|
|
filters,
|
|
)?;
|
|
}
|
|
_ => {
|
|
// Other conditions go into the filter
|
|
// We need a combined schema to build the expression
|
|
let combined_schema = self.combine_schemas(left_schema, right_schema)?;
|
|
let filter_expr = self.build_expr(expr, &combined_schema)?;
|
|
filters.push(filter_expr);
|
|
}
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
// Helper: Build equijoin conditions from USING clause
|
|
fn build_using_conditions(
|
|
&mut self,
|
|
columns: &[ast::Name],
|
|
left_schema: &SchemaRef,
|
|
right_schema: &SchemaRef,
|
|
) -> Result<Vec<(LogicalExpr, LogicalExpr)>> {
|
|
let mut conditions = Vec::new();
|
|
|
|
for col_name in columns {
|
|
let name = Self::name_to_string(col_name);
|
|
|
|
// Find the column in both schemas
|
|
let _left_idx = left_schema
|
|
.columns
|
|
.iter()
|
|
.position(|col| col.name == name)
|
|
.ok_or_else(|| {
|
|
LimboError::ParseError(format!("Column {name} not found in left table"))
|
|
})?;
|
|
let _right_idx = right_schema
|
|
.columns
|
|
.iter()
|
|
.position(|col| col.name == name)
|
|
.ok_or_else(|| {
|
|
LimboError::ParseError(format!("Column {name} not found in right table"))
|
|
})?;
|
|
|
|
conditions.push((
|
|
LogicalExpr::Column(Column {
|
|
name: name.clone(),
|
|
table: None, // Will be resolved later
|
|
}),
|
|
LogicalExpr::Column(Column {
|
|
name,
|
|
table: None, // Will be resolved later
|
|
}),
|
|
));
|
|
}
|
|
|
|
Ok(conditions)
|
|
}
|
|
|
|
// Helper: Build natural join by finding common columns
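    //
    // For example (illustrative), `t1 NATURAL JOIN t2` with a shared column `id`
    // behaves like `t1 JOIN t2 USING (id)`; with no shared columns it degrades
    // to a cross join.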
|
|
fn build_natural_join(
|
|
&mut self,
|
|
left: LogicalPlan,
|
|
right: LogicalPlan,
|
|
join_type: JoinType,
|
|
) -> Result<LogicalPlan> {
|
|
let left_schema = left.schema();
|
|
let right_schema = right.schema();
|
|
|
|
// Find common column names
|
|
let mut common_columns = Vec::new();
|
|
for left_col in &left_schema.columns {
|
|
if right_schema
|
|
.columns
|
|
.iter()
|
|
.any(|col| col.name == left_col.name)
|
|
{
|
|
common_columns.push(ast::Name::Ident(left_col.name.clone()));
|
|
}
|
|
}
|
|
|
|
if common_columns.is_empty() {
|
|
// Natural join with no common columns becomes a cross join
|
|
let schema = self.build_join_schema(&left, &right, &JoinType::Cross)?;
|
|
return Ok(LogicalPlan::Join(Join {
|
|
left: Arc::new(left),
|
|
right: Arc::new(right),
|
|
join_type: JoinType::Cross,
|
|
on: Vec::new(),
|
|
filter: None,
|
|
schema,
|
|
}));
|
|
}
|
|
|
|
// Build equijoin conditions for common columns
|
|
let on = self.build_using_conditions(&common_columns, left_schema, right_schema)?;
|
|
let schema = self.build_join_schema(&left, &right, &join_type)?;
|
|
|
|
Ok(LogicalPlan::Join(Join {
|
|
left: Arc::new(left),
|
|
right: Arc::new(right),
|
|
join_type,
|
|
on,
|
|
filter: None,
|
|
schema,
|
|
}))
|
|
}
|
|
|
|
// Helper: Build schema for join result
|
|
fn build_join_schema(
|
|
&self,
|
|
left: &LogicalPlan,
|
|
right: &LogicalPlan,
|
|
_join_type: &JoinType,
|
|
) -> Result<SchemaRef> {
|
|
let left_schema = left.schema();
|
|
let right_schema = right.schema();
|
|
|
|
// Concatenate the schemas, preserving all column information
|
|
let mut columns = Vec::new();
|
|
|
|
// Keep all columns from left with their table info
|
|
for col in &left_schema.columns {
|
|
columns.push(col.clone());
|
|
}
|
|
|
|
// Keep all columns from right with their table info
|
|
for col in &right_schema.columns {
|
|
columns.push(col.clone());
|
|
}
|
|
|
|
Ok(Arc::new(LogicalSchema::new(columns)))
|
|
}
|
|
|
|
// Helper: Combine two schemas for expression building
|
|
fn combine_schemas(&self, left: &SchemaRef, right: &SchemaRef) -> Result<SchemaRef> {
|
|
let mut columns = left.columns.clone();
|
|
columns.extend(right.columns.clone());
|
|
Ok(Arc::new(LogicalSchema::new(columns)))
|
|
}
|
|
|
|
// Build projection
|
|
fn build_projection(
|
|
&mut self,
|
|
input: LogicalPlan,
|
|
columns: &[ast::ResultColumn],
|
|
) -> Result<LogicalPlan> {
|
|
let input_schema = input.schema();
|
|
let mut proj_exprs = Vec::new();
|
|
let mut schema_columns = Vec::new();
|
|
|
|
for col in columns {
|
|
match col {
|
|
ast::ResultColumn::Expr(expr, alias) => {
|
|
let logical_expr = self.build_expr(expr, input_schema)?;
|
|
let col_name = match alias {
|
|
Some(as_alias) => match as_alias {
|
|
ast::As::As(name) | ast::As::Elided(name) => Self::name_to_string(name),
|
|
},
|
|
None => Self::expr_to_column_name(expr),
|
|
};
|
|
let col_type = Self::infer_expr_type(&logical_expr, input_schema)?;
|
|
|
|
schema_columns.push(ColumnInfo {
|
|
name: col_name.clone(),
|
|
ty: col_type,
|
|
database: None,
|
|
table: None,
|
|
table_alias: None,
|
|
});
|
|
|
|
if let Some(as_alias) = alias {
|
|
let alias_name = match as_alias {
|
|
ast::As::As(name) | ast::As::Elided(name) => Self::name_to_string(name),
|
|
};
|
|
proj_exprs.push(LogicalExpr::Alias {
|
|
expr: Box::new(logical_expr),
|
|
alias: alias_name,
|
|
});
|
|
} else {
|
|
proj_exprs.push(logical_expr);
|
|
}
|
|
}
|
|
ast::ResultColumn::Star => {
|
|
// Expand * to all columns
|
|
for col in &input_schema.columns {
|
|
proj_exprs.push(LogicalExpr::Column(Column::new(col.name.clone())));
|
|
schema_columns.push(col.clone());
|
|
}
|
|
}
|
|
ast::ResultColumn::TableStar(table) => {
|
|
// Expand table.* to all columns from that table
|
|
let table_name = Self::name_to_string(table);
|
|
for col in &input_schema.columns {
|
|
// Simple check - would need proper table tracking in real implementation
|
|
proj_exprs.push(LogicalExpr::Column(Column::with_table(
|
|
col.name.clone(),
|
|
table_name.clone(),
|
|
)));
|
|
schema_columns.push(col.clone());
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
Ok(LogicalPlan::Projection(Projection {
|
|
input: Arc::new(input),
|
|
exprs: proj_exprs,
|
|
schema: Arc::new(LogicalSchema::new(schema_columns)),
|
|
}))
|
|
}
|
|
|
|
// Helper function to preprocess aggregate expressions that contain complex arguments
|
|
// Returns: (needs_pre_projection, pre_projection_exprs, pre_projection_schema, modified_aggr_exprs)
|
|
//
|
|
// This will be used in expressions like select sum(hex(a + 2)) from tbl => hex(a + 2) is a
|
|
// pre-projection.
|
|
//
|
|
// Another alternative is to always generate a projection together with an aggregation, and
|
|
// just have "a" be the identity projection if we don't have a complex case. But that's quite
|
|
// wasteful.
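    //
    // Resulting plan shape (a rough sketch, not exact node output): for
    // `SELECT SUM(HEX(a + 2)) FROM tbl` the builder produces roughly
    //   Aggregate(aggr: [SUM(__agg_arg_proj_0)])
    //     Projection(exprs: [HEX(a + 2) AS __agg_arg_proj_0])
    //       TableScan(tbl)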
|
|
fn preprocess_aggregate_expressions(
|
|
aggr_exprs: &[LogicalExpr],
|
|
group_exprs: &[LogicalExpr],
|
|
input_schema: &SchemaRef,
|
|
) -> Result<PreprocessAggregateResult> {
|
|
let mut needs_pre_projection = false;
|
|
let mut pre_projection_exprs = Vec::new();
|
|
let mut pre_projection_schema = Vec::new();
|
|
let mut modified_aggr_exprs = Vec::new();
|
|
let mut projected_col_counter = 0;
|
|
|
|
// First, add all group by expressions to the pre-projection
|
|
for expr in group_exprs {
|
|
if let LogicalExpr::Column(col) = expr {
|
|
pre_projection_exprs.push(expr.clone());
|
|
let col_type = Self::infer_expr_type(expr, input_schema)?;
|
|
pre_projection_schema.push(ColumnInfo {
|
|
name: col.name.clone(),
|
|
ty: col_type,
|
|
database: None,
|
|
table: col.table.clone(),
|
|
table_alias: None,
|
|
});
|
|
} else {
|
|
// Complex group by expression - project it
|
|
needs_pre_projection = true;
|
|
let proj_col_name = format!("__group_proj_{projected_col_counter}");
|
|
projected_col_counter += 1;
|
|
pre_projection_exprs.push(expr.clone());
|
|
let col_type = Self::infer_expr_type(expr, input_schema)?;
|
|
pre_projection_schema.push(ColumnInfo {
|
|
name: proj_col_name.clone(),
|
|
ty: col_type,
|
|
database: None,
|
|
table: None,
|
|
table_alias: None,
|
|
});
|
|
}
|
|
}
|
|
|
|
// Check each aggregate expression
|
|
for agg_expr in aggr_exprs {
|
|
if let LogicalExpr::AggregateFunction {
|
|
fun,
|
|
args,
|
|
distinct,
|
|
} = agg_expr
|
|
{
|
|
let mut modified_args = Vec::new();
|
|
for arg in args {
|
|
// Check if the argument is a simple column reference or a complex expression
|
|
match arg {
|
|
LogicalExpr::Column(_) => {
|
|
// Simple column - just use it
|
|
modified_args.push(arg.clone());
|
|
// Make sure the column is in the pre-projection
|
|
if !pre_projection_exprs.iter().any(|e| e == arg) {
|
|
pre_projection_exprs.push(arg.clone());
|
|
let col_type = Self::infer_expr_type(arg, input_schema)?;
|
|
if let LogicalExpr::Column(col) = arg {
|
|
pre_projection_schema.push(ColumnInfo {
|
|
name: col.name.clone(),
|
|
ty: col_type,
|
|
database: None,
|
|
table: col.table.clone(),
|
|
table_alias: None,
|
|
});
|
|
}
|
|
}
|
|
}
|
|
_ => {
|
|
// Complex expression - we need to project it first
|
|
needs_pre_projection = true;
|
|
let proj_col_name = format!("__agg_arg_proj_{projected_col_counter}");
|
|
projected_col_counter += 1;
|
|
|
|
// Add the expression to the pre-projection
|
|
pre_projection_exprs.push(arg.clone());
|
|
let col_type = Self::infer_expr_type(arg, input_schema)?;
|
|
pre_projection_schema.push(ColumnInfo {
|
|
name: proj_col_name.clone(),
|
|
ty: col_type,
|
|
database: None,
|
|
table: None,
|
|
table_alias: None,
|
|
});
|
|
|
|
// In the aggregate, reference the projected column
|
|
modified_args.push(LogicalExpr::Column(Column::new(proj_col_name)));
|
|
}
|
|
}
|
|
}
|
|
|
|
// Create the modified aggregate expression
|
|
modified_aggr_exprs.push(LogicalExpr::AggregateFunction {
|
|
fun: fun.clone(),
|
|
args: modified_args,
|
|
distinct: *distinct,
|
|
});
|
|
} else {
|
|
modified_aggr_exprs.push(agg_expr.clone());
|
|
}
|
|
}
|
|
|
|
Ok((
|
|
needs_pre_projection,
|
|
pre_projection_exprs,
|
|
pre_projection_schema,
|
|
modified_aggr_exprs,
|
|
))
|
|
}
|
|
|
|
// Build aggregate with GROUP BY
|
|
fn build_aggregate(
|
|
&mut self,
|
|
input: LogicalPlan,
|
|
group_by: &ast::GroupBy,
|
|
columns: &[ast::ResultColumn],
|
|
) -> Result<LogicalPlan> {
|
|
let input_schema = input.schema();
|
|
|
|
// Build grouping expressions
|
|
let mut group_exprs = Vec::new();
|
|
for expr in &group_by.exprs {
|
|
group_exprs.push(self.build_expr(expr, input_schema)?);
|
|
}
|
|
|
|
// Use the unified aggregate builder
|
|
self.build_aggregate_internal(input, group_exprs, columns)
|
|
}
|
|
|
|
// Build aggregate without GROUP BY
|
|
fn build_aggregate_no_group(
|
|
&mut self,
|
|
input: LogicalPlan,
|
|
columns: &[ast::ResultColumn],
|
|
) -> Result<LogicalPlan> {
|
|
// Use the unified aggregate builder with empty group expressions
|
|
self.build_aggregate_internal(input, vec![], columns)
|
|
}
|
|
|
|
// Unified internal aggregate builder that handles both GROUP BY and non-GROUP BY cases
|
|
fn build_aggregate_internal(
|
|
&mut self,
|
|
input: LogicalPlan,
|
|
group_exprs: Vec<LogicalExpr>,
|
|
columns: &[ast::ResultColumn],
|
|
) -> Result<LogicalPlan> {
|
|
let input_schema = input.schema();
|
|
let has_group_by = !group_exprs.is_empty();
|
|
|
|
// Build aggregate expressions and projection expressions
|
|
let mut aggr_exprs = Vec::new();
|
|
let mut projection_exprs = Vec::new();
|
|
let mut aggregate_schema_columns = Vec::new();
|
|
|
|
// First, add GROUP BY columns to the aggregate output schema
|
|
// These are always part of the aggregate operator's output
|
|
for group_expr in &group_exprs {
|
|
match group_expr {
|
|
LogicalExpr::Column(col) => {
|
|
// For column references in GROUP BY, preserve the original column info
|
|
if let Some((_, col_info)) =
|
|
input_schema.find_column(&col.name, col.table.as_deref())
|
|
{
|
|
// Preserve the column with all its table information
|
|
aggregate_schema_columns.push(col_info.clone());
|
|
} else {
|
|
// Fallback if column not found (shouldn't happen)
|
|
let col_type = Self::infer_expr_type(group_expr, input_schema)?;
|
|
aggregate_schema_columns.push(ColumnInfo {
|
|
name: col.name.clone(),
|
|
ty: col_type,
|
|
database: None,
|
|
table: col.table.clone(),
|
|
table_alias: None,
|
|
});
|
|
}
|
|
}
|
|
_ => {
|
|
// For complex GROUP BY expressions, generate a name
|
|
let col_name = format!("__group_{}", aggregate_schema_columns.len());
|
|
let col_type = Self::infer_expr_type(group_expr, input_schema)?;
|
|
aggregate_schema_columns.push(ColumnInfo {
|
|
name: col_name,
|
|
ty: col_type,
|
|
database: None,
|
|
table: None,
|
|
table_alias: None,
|
|
});
|
|
}
|
|
}
|
|
}
|
|
|
|
// Track aggregates we've already seen to avoid duplicates
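        // For example (illustrative), `SELECT SUM(a), SUM(a) * 2 FROM t` computes
        // `SUM(a)` once and references it from both projection expressions.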
|
|
let mut aggregate_map: std::collections::HashMap<String, String> =
|
|
std::collections::HashMap::new();
|
|
|
|
for col in columns {
|
|
match col {
|
|
ast::ResultColumn::Expr(expr, alias) => {
|
|
let logical_expr = self.build_expr(expr, input_schema)?;
|
|
|
|
// Determine the column name for this expression
|
|
let col_name = match alias {
|
|
Some(as_alias) => match as_alias {
|
|
ast::As::As(name) | ast::As::Elided(name) => Self::name_to_string(name),
|
|
},
|
|
None => Self::expr_to_column_name(expr),
|
|
};
|
|
|
|
// Check if the TOP-LEVEL expression is an aggregate
|
|
// We only care about immediate aggregates, not nested ones
|
|
if Self::is_aggregate_expr(&logical_expr) {
|
|
// Pure aggregate function - check if we've seen it before
|
|
let agg_key = format!("{logical_expr:?}");
|
|
|
|
let agg_col_name = if let Some(existing_name) = aggregate_map.get(&agg_key)
|
|
{
|
|
// Reuse existing aggregate
|
|
existing_name.clone()
|
|
} else {
|
|
// New aggregate - add it
|
|
let col_type = Self::infer_expr_type(&logical_expr, input_schema)?;
|
|
aggregate_schema_columns.push(ColumnInfo {
|
|
name: col_name.clone(),
|
|
ty: col_type,
|
|
database: None,
|
|
table: None,
|
|
table_alias: None,
|
|
});
|
|
aggr_exprs.push(logical_expr);
|
|
aggregate_map.insert(agg_key, col_name.clone());
|
|
col_name.clone()
|
|
};
|
|
|
|
// In the projection, reference this aggregate by name
|
|
projection_exprs.push(LogicalExpr::Column(Column {
|
|
name: agg_col_name,
|
|
table: None,
|
|
}));
|
|
} else if Self::contains_aggregate(&logical_expr) {
|
|
// This is an expression that contains an aggregate somewhere
|
|
// (e.g., sum(a + 2) * 2)
|
|
// We need to extract aggregates and replace them with column references
|
|
let (processed_expr, extracted_aggs) =
|
|
Self::extract_and_replace_aggregates_with_dedup(
|
|
logical_expr,
|
|
&mut aggregate_map,
|
|
)?;
|
|
|
|
// Add only new aggregates
|
|
for (agg_expr, agg_name) in extracted_aggs {
|
|
let agg_type = Self::infer_expr_type(&agg_expr, input_schema)?;
|
|
aggregate_schema_columns.push(ColumnInfo {
|
|
name: agg_name,
|
|
ty: agg_type,
|
|
database: None,
|
|
table: None,
|
|
table_alias: None,
|
|
});
|
|
aggr_exprs.push(agg_expr);
|
|
}
|
|
|
|
// Add the processed expression (with column refs) to projection
|
|
projection_exprs.push(processed_expr);
|
|
} else {
|
|
// Non-aggregate expression - validation depends on GROUP BY presence
|
|
if has_group_by {
|
|
// With GROUP BY: only allow constants and grouped columns
|
|
// TODO: SQLite actually allows any column here and returns the value from
|
|
// the first row encountered in each group. We should support this in the
|
|
// future for full SQLite compatibility, but for now we're stricter to
|
|
// simplify the DBSP compilation.
|
|
if !Self::is_constant_expr(&logical_expr)
|
|
&& !Self::is_valid_in_group_by(&logical_expr, &group_exprs)
|
|
{
|
|
return Err(LimboError::ParseError(format!(
|
|
"Column '{col_name}' must appear in the GROUP BY clause or be used in an aggregate function"
|
|
)));
|
|
}
|
|
} else {
|
|
// Without GROUP BY: only allow constant expressions
|
|
// TODO: SQLite allows any column here and returns a value from an
|
|
// arbitrary row. We should support this for full compatibility,
|
|
// but for now we're stricter to simplify DBSP compilation.
|
|
if !Self::is_constant_expr(&logical_expr) {
|
|
return Err(LimboError::ParseError(format!(
|
|
"Column '{col_name}' must be used in an aggregate function when using aggregates without GROUP BY"
|
|
)));
|
|
}
|
|
}
|
|
projection_exprs.push(logical_expr);
|
|
}
|
|
}
|
|
_ => {
|
|
let error_msg = if has_group_by {
|
|
"* not supported with GROUP BY".to_string()
|
|
} else {
|
|
"* not supported with aggregate functions".to_string()
|
|
};
|
|
return Err(LimboError::ParseError(error_msg));
|
|
}
|
|
}
|
|
}
|
|
|
|
// Check if any aggregate functions have complex expressions as arguments
|
|
// If so, we need to insert a projection before the aggregate
|
|
let (
|
|
needs_pre_projection,
|
|
pre_projection_exprs,
|
|
pre_projection_schema,
|
|
modified_aggr_exprs,
|
|
) = Self::preprocess_aggregate_expressions(&aggr_exprs, &group_exprs, input_schema)?;
|
|
|
|
// Build the final schema for the projection
|
|
let mut projection_schema_columns = Vec::new();
|
|
for (i, expr) in projection_exprs.iter().enumerate() {
|
|
let col_name = if i < columns.len() {
|
|
match &columns[i] {
|
|
ast::ResultColumn::Expr(e, alias) => match alias {
|
|
Some(as_alias) => match as_alias {
|
|
ast::As::As(name) | ast::As::Elided(name) => Self::name_to_string(name),
|
|
},
|
|
None => Self::expr_to_column_name(e),
|
|
},
|
|
_ => format!("col_{i}"),
|
|
}
|
|
} else {
|
|
format!("col_{i}")
|
|
};
|
|
|
|
// For type inference, we need the aggregate schema for column references
|
|
let aggregate_schema = LogicalSchema::new(aggregate_schema_columns.clone());
|
|
let col_type = Self::infer_expr_type(expr, &Arc::new(aggregate_schema))?;
|
|
projection_schema_columns.push(ColumnInfo {
|
|
name: col_name,
|
|
ty: col_type,
|
|
database: None,
|
|
table: None,
|
|
table_alias: None,
|
|
});
|
|
}
|
|
|
|
// Create the input plan (with pre-projection if needed)
|
|
let aggregate_input = if needs_pre_projection {
|
|
Arc::new(LogicalPlan::Projection(Projection {
|
|
input: Arc::new(input),
|
|
exprs: pre_projection_exprs,
|
|
schema: Arc::new(LogicalSchema::new(pre_projection_schema)),
|
|
}))
|
|
} else {
|
|
Arc::new(input)
|
|
};
|
|
|
|
// Use modified aggregate expressions if we inserted a pre-projection
|
|
let final_aggr_exprs = if needs_pre_projection {
|
|
modified_aggr_exprs
|
|
} else {
|
|
aggr_exprs
|
|
};
|
|
|
|
// Check if we need the outer projection
|
|
// We need a projection if:
|
|
// 1. We have expressions that compute new values (e.g., SUM(x) * 2)
|
|
// 2. We're selecting a different set of columns than GROUP BY + aggregates
|
|
// 3. We're reordering columns from their natural aggregate output order
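        //
        // For example (illustrative), `SELECT a, SUM(b) FROM t GROUP BY a` matches
        // the aggregate's natural output and needs no extra projection, while
        // `SELECT SUM(b) * 2 FROM t` or `SELECT SUM(b), a FROM t GROUP BY a` does.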
|
|
let needs_outer_projection = {
|
|
// Check for complex expressions
|
|
let has_complex_exprs = projection_exprs
|
|
.iter()
|
|
.any(|expr| !matches!(expr, LogicalExpr::Column(_)));
|
|
|
|
if has_complex_exprs {
|
|
true
|
|
} else {
|
|
// Check if we're selecting exactly what aggregate outputs in the same order
|
|
// The aggregate outputs: all GROUP BY columns, then all aggregate expressions
|
|
// The projection might select a subset or reorder these
|
|
|
|
if projection_exprs.len() != aggregate_schema_columns.len() {
|
|
// Different number of columns
|
|
true
|
|
} else {
|
|
// Check if columns match in order and name
|
|
!projection_exprs.iter().zip(&aggregate_schema_columns).all(
|
|
|(expr, agg_col)| {
|
|
if let LogicalExpr::Column(col) = expr {
|
|
col.name == agg_col.name
|
|
} else {
|
|
false
|
|
}
|
|
},
|
|
)
|
|
}
|
|
}
|
|
};
|
|
|
|
// Create the aggregate node with its natural schema
|
|
let aggregate_plan = LogicalPlan::Aggregate(Aggregate {
|
|
input: aggregate_input,
|
|
group_expr: group_exprs,
|
|
aggr_expr: final_aggr_exprs,
|
|
schema: Arc::new(LogicalSchema::new(aggregate_schema_columns)),
|
|
});
|
|
|
|
if needs_outer_projection {
|
|
Ok(LogicalPlan::Projection(Projection {
|
|
input: Arc::new(aggregate_plan),
|
|
exprs: projection_exprs,
|
|
schema: Arc::new(LogicalSchema::new(projection_schema_columns)),
|
|
}))
|
|
} else {
|
|
// No projection needed - aggregate output matches what we want
|
|
Ok(aggregate_plan)
|
|
}
|
|
}
|
|
|
|
/// Build VALUES clause
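    ///
    /// Columns are named `column1`, `column2`, ... (see the schema construction
    /// below); for example `VALUES (1, 'x')` yields columns `column1` and `column2`.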
|
|
#[allow(clippy::vec_box)]
|
|
fn build_values(&mut self, values: &[Vec<Box<ast::Expr>>]) -> Result<LogicalPlan> {
|
|
if values.is_empty() {
|
|
return Err(LimboError::ParseError("Empty VALUES clause".to_string()));
|
|
}
|
|
|
|
let mut rows = Vec::new();
|
|
let first_row_len = values[0].len();
|
|
|
|
// Infer schema from first row
|
|
let mut schema_columns = Vec::new();
|
|
for (i, _) in values[0].iter().enumerate() {
|
|
schema_columns.push(ColumnInfo {
|
|
name: format!("column{}", i + 1),
|
|
ty: Type::Text,
|
|
database: None,
|
|
table: None,
|
|
table_alias: None,
|
|
});
|
|
}
|
|
|
|
for row in values {
|
|
if row.len() != first_row_len {
|
|
return Err(LimboError::ParseError(
|
|
"All rows in VALUES must have the same number of columns".to_string(),
|
|
));
|
|
}
|
|
|
|
let mut logical_row = Vec::new();
|
|
for expr in row {
|
|
// VALUES doesn't have input schema
|
|
let empty_schema = Arc::new(LogicalSchema::empty());
|
|
logical_row.push(self.build_expr(expr, &empty_schema)?);
|
|
}
|
|
rows.push(logical_row);
|
|
}
|
|
|
|
Ok(LogicalPlan::Values(Values {
|
|
rows,
|
|
schema: Arc::new(LogicalSchema::new(schema_columns)),
|
|
}))
|
|
}
|
|
|
|
// Build SORT
|
|
fn build_sort(
|
|
&mut self,
|
|
input: LogicalPlan,
|
|
exprs: &[ast::SortedColumn],
|
|
) -> Result<LogicalPlan> {
|
|
let input_schema = input.schema();
|
|
let mut sort_exprs = Vec::new();
|
|
|
|
for sorted_col in exprs {
|
|
let expr = self.build_expr(&sorted_col.expr, input_schema)?;
|
|
sort_exprs.push(SortExpr {
|
|
expr,
|
|
asc: sorted_col.order != Some(ast::SortOrder::Desc),
|
|
nulls_first: sorted_col.nulls == Some(ast::NullsOrder::First),
|
|
});
|
|
}
|
|
|
|
Ok(LogicalPlan::Sort(Sort {
|
|
input: Arc::new(input),
|
|
exprs: sort_exprs,
|
|
}))
|
|
}
|
|
|
|
// Build LIMIT
|
|
fn build_limit(input: LogicalPlan, limit: &ast::Limit) -> Result<LogicalPlan> {
|
|
let fetch = match limit.expr.as_ref() {
|
|
ast::Expr::Literal(ast::Literal::Numeric(s)) => s.parse::<usize>().ok(),
|
|
_ => {
|
|
return Err(LimboError::ParseError(
|
|
"LIMIT must be a literal integer".to_string(),
|
|
))
|
|
}
|
|
};
|
|
|
|
let skip = if let Some(offset) = &limit.offset {
|
|
match offset.as_ref() {
|
|
ast::Expr::Literal(ast::Literal::Numeric(s)) => s.parse::<usize>().ok(),
|
|
_ => {
|
|
return Err(LimboError::ParseError(
|
|
"OFFSET must be a literal integer".to_string(),
|
|
))
|
|
}
|
|
}
|
|
} else {
|
|
None
|
|
};
|
|
|
|
Ok(LogicalPlan::Limit(Limit {
|
|
input: Arc::new(input),
|
|
skip,
|
|
fetch,
|
|
}))
|
|
}
|
|
|
|
// Build compound operator (UNION, INTERSECT, EXCEPT)
|
|
fn build_compound(
|
|
left: LogicalPlan,
|
|
right: LogicalPlan,
|
|
op: &ast::CompoundOperator,
|
|
) -> Result<LogicalPlan> {
|
|
// Check schema compatibility
|
|
if left.schema().column_count() != right.schema().column_count() {
|
|
return Err(LimboError::ParseError(
|
|
"UNION/INTERSECT/EXCEPT requires same number of columns".to_string(),
|
|
));
|
|
}
|
|
|
|
let all = matches!(op, ast::CompoundOperator::UnionAll);
|
|
|
|
match op {
|
|
ast::CompoundOperator::Union | ast::CompoundOperator::UnionAll => {
|
|
let schema = left.schema().clone();
|
|
Ok(LogicalPlan::Union(Union {
|
|
inputs: vec![Arc::new(left), Arc::new(right)],
|
|
all,
|
|
schema,
|
|
}))
|
|
}
|
|
_ => Err(LimboError::ParseError(
|
|
"INTERSECT and EXCEPT not yet supported in logical plans".to_string(),
|
|
)),
|
|
}
|
|
}
|
|
|
|
// Build expression from AST
|
|
fn build_expr(&mut self, expr: &ast::Expr, _schema: &SchemaRef) -> Result<LogicalExpr> {
|
|
match expr {
|
|
ast::Expr::Id(name) => Ok(LogicalExpr::Column(Column::new(Self::name_to_string(name)))),
|
|
|
|
ast::Expr::DoublyQualified(db, table, col) => {
|
|
Ok(LogicalExpr::Column(Column::with_table(
|
|
Self::name_to_string(col),
|
|
format!(
|
|
"{}.{}",
|
|
Self::name_to_string(db),
|
|
Self::name_to_string(table)
|
|
),
|
|
)))
|
|
}
|
|
|
|
ast::Expr::Qualified(table, col) => Ok(LogicalExpr::Column(Column::with_table(
|
|
Self::name_to_string(col),
|
|
Self::name_to_string(table),
|
|
))),
|
|
|
|
ast::Expr::Literal(lit) => Ok(LogicalExpr::Literal(Self::build_literal(lit)?)),
|
|
|
|
ast::Expr::Binary(lhs, op, rhs) => {
|
|
// Special case: IS NULL and IS NOT NULL
|
|
if matches!(op, ast::Operator::Is | ast::Operator::IsNot) {
|
|
if let ast::Expr::Literal(ast::Literal::Null) = rhs.as_ref() {
|
|
let expr = Box::new(self.build_expr(lhs, _schema)?);
|
|
return Ok(LogicalExpr::IsNull {
|
|
expr,
|
|
negated: matches!(op, ast::Operator::IsNot),
|
|
});
|
|
}
|
|
}
|
|
|
|
let left = Box::new(self.build_expr(lhs, _schema)?);
|
|
let right = Box::new(self.build_expr(rhs, _schema)?);
|
|
Ok(LogicalExpr::BinaryExpr {
|
|
left,
|
|
op: *op,
|
|
right,
|
|
})
|
|
}
|
|
|
|
ast::Expr::Unary(op, expr) => {
|
|
let inner = Box::new(self.build_expr(expr, _schema)?);
|
|
Ok(LogicalExpr::UnaryExpr {
|
|
op: *op,
|
|
expr: inner,
|
|
})
|
|
}
|
|
|
|
ast::Expr::FunctionCall {
|
|
name,
|
|
distinctness,
|
|
args,
|
|
filter_over,
|
|
..
|
|
} => {
|
|
// Check for window functions (OVER clause)
|
|
if filter_over.over_clause.is_some() {
|
|
return Err(LimboError::ParseError(
|
|
"Unsupported expression type: window functions are not yet supported"
|
|
.to_string(),
|
|
));
|
|
}
|
|
|
|
let func_name = Self::name_to_string(name);
|
|
let arg_count = args.len();
|
|
// Check if it's an aggregate function (considering argument count for min/max)
|
|
if let Some(agg_fun) = Self::parse_aggregate_function(&func_name, arg_count) {
|
|
let distinct = distinctness.is_some();
|
|
let arg_exprs = args
|
|
.iter()
|
|
.map(|e| self.build_expr(e, _schema))
|
|
.collect::<Result<Vec<_>>>()?;
|
|
Ok(LogicalExpr::AggregateFunction {
|
|
fun: agg_fun,
|
|
args: arg_exprs,
|
|
distinct,
|
|
})
|
|
} else {
|
|
// Regular scalar function
|
|
let arg_exprs = args
|
|
.iter()
|
|
.map(|e| self.build_expr(e, _schema))
|
|
.collect::<Result<Vec<_>>>()?;
|
|
Ok(LogicalExpr::ScalarFunction {
|
|
fun: func_name,
|
|
args: arg_exprs,
|
|
})
|
|
}
|
|
}
|
|
|
|
ast::Expr::FunctionCallStar { name, .. } => {
|
|
// Handle COUNT(*) and similar
|
|
let func_name = Self::name_to_string(name);
|
|
// FunctionCallStar always has 0 args (it's the * form)
|
|
if let Some(agg_fun) = Self::parse_aggregate_function(&func_name, 0) {
|
|
Ok(LogicalExpr::AggregateFunction {
|
|
fun: agg_fun,
|
|
args: vec![],
|
|
distinct: false,
|
|
})
|
|
} else {
|
|
Err(LimboError::ParseError(format!(
|
|
"Function {func_name}(*) is not supported"
|
|
)))
|
|
}
|
|
}
|
|
|
|
ast::Expr::Case {
|
|
base,
|
|
when_then_pairs,
|
|
else_expr,
|
|
} => {
|
|
let case_expr = if let Some(e) = base {
|
|
Some(Box::new(self.build_expr(e, _schema)?))
|
|
} else {
|
|
None
|
|
};
|
|
|
|
let when_then_exprs = when_then_pairs
|
|
.iter()
|
|
.map(|(when, then)| {
|
|
Ok((
|
|
self.build_expr(when, _schema)?,
|
|
self.build_expr(then, _schema)?,
|
|
))
|
|
})
|
|
.collect::<Result<Vec<_>>>()?;
|
|
|
|
let else_result = if let Some(e) = else_expr {
|
|
Some(Box::new(self.build_expr(e, _schema)?))
|
|
} else {
|
|
None
|
|
};
|
|
|
|
Ok(LogicalExpr::Case {
|
|
expr: case_expr,
|
|
when_then: when_then_exprs,
|
|
else_expr: else_result,
|
|
})
|
|
}
|
|
|
|
ast::Expr::InList { lhs, not, rhs } => {
|
|
let expr = Box::new(self.build_expr(lhs, _schema)?);
|
|
let list = rhs
|
|
.iter()
|
|
.map(|e| self.build_expr(e, _schema))
|
|
.collect::<Result<Vec<_>>>()?;
|
|
Ok(LogicalExpr::InList {
|
|
expr,
|
|
list,
|
|
negated: *not,
|
|
})
|
|
}
|
|
|
|
ast::Expr::InSelect { lhs, not, rhs } => {
|
|
let expr = Box::new(self.build_expr(lhs, _schema)?);
|
|
let subquery = Arc::new(self.build_select(rhs)?);
|
|
Ok(LogicalExpr::InSubquery {
|
|
expr,
|
|
subquery,
|
|
negated: *not,
|
|
})
|
|
}
|
|
|
|
ast::Expr::Exists(select) => {
|
|
let subquery = Arc::new(self.build_select(select)?);
|
|
Ok(LogicalExpr::Exists {
|
|
subquery,
|
|
negated: false,
|
|
})
|
|
}
|
|
|
|
ast::Expr::Subquery(select) => {
|
|
let subquery = Arc::new(self.build_select(select)?);
|
|
Ok(LogicalExpr::ScalarSubquery(subquery))
|
|
}
|
|
|
|
ast::Expr::IsNull(lhs) => {
|
|
let expr = Box::new(self.build_expr(lhs, _schema)?);
|
|
Ok(LogicalExpr::IsNull {
|
|
expr,
|
|
negated: false,
|
|
})
|
|
}
|
|
|
|
ast::Expr::NotNull(lhs) => {
|
|
let expr = Box::new(self.build_expr(lhs, _schema)?);
|
|
Ok(LogicalExpr::IsNull {
|
|
expr,
|
|
negated: true,
|
|
})
|
|
}
|
|
|
|
ast::Expr::Between {
|
|
lhs,
|
|
not,
|
|
start,
|
|
end,
|
|
} => {
|
|
let expr = Box::new(self.build_expr(lhs, _schema)?);
|
|
let low = Box::new(self.build_expr(start, _schema)?);
|
|
let high = Box::new(self.build_expr(end, _schema)?);
|
|
Ok(LogicalExpr::Between {
|
|
expr,
|
|
low,
|
|
high,
|
|
negated: *not,
|
|
})
|
|
}
|
|
|
|
ast::Expr::Like {
|
|
lhs,
|
|
not,
|
|
op: _,
|
|
rhs,
|
|
escape,
|
|
} => {
|
|
let expr = Box::new(self.build_expr(lhs, _schema)?);
|
|
let pattern = Box::new(self.build_expr(rhs, _schema)?);
|
|
let escape_char = escape.as_ref().and_then(|e| {
|
|
if let ast::Expr::Literal(ast::Literal::String(s)) = e.as_ref() {
|
|
s.chars().next()
|
|
} else {
|
|
None
|
|
}
|
|
});
|
|
Ok(LogicalExpr::Like {
|
|
expr,
|
|
pattern,
|
|
escape: escape_char,
|
|
negated: *not,
|
|
})
|
|
}
|
|
|
|
ast::Expr::Parenthesized(exprs) => {
|
|
                // We assume the parser always produces at least one expression
                // inside the parentheses; the assert below enforces that invariant.
|
|
assert!(!exprs.is_empty());
|
|
// Multiple expressions in parentheses is unusual but handle it
|
|
// by building the first one (SQLite behavior)
|
|
self.build_expr(&exprs[0], _schema)
|
|
}
|
|
|
|
ast::Expr::Cast { expr, type_name } => {
|
|
let inner = self.build_expr(expr, _schema)?;
|
|
Ok(LogicalExpr::Cast {
|
|
expr: Box::new(inner),
|
|
type_name: type_name.clone(),
|
|
})
|
|
}
|
|
|
|
_ => Err(LimboError::ParseError(format!(
|
|
"Unsupported expression type in logical plan: {expr:?}"
|
|
))),
|
|
}
|
|
}
|
|
|
|
/// Build literal value
|
|
fn build_literal(lit: &ast::Literal) -> Result<Value> {
|
|
match lit {
|
|
ast::Literal::Null => Ok(Value::Null),
|
|
ast::Literal::Keyword(k) => {
|
|
let k_bytes = k.as_bytes();
|
|
match_ignore_ascii_case!(match k_bytes {
|
|
b"true" => Ok(Value::Integer(1)), // SQLite uses int for bool
|
|
b"false" => Ok(Value::Integer(0)), // SQLite uses int for bool
|
|
_ => Ok(Value::Text(k.clone().into())),
|
|
})
|
|
}
|
|
ast::Literal::Numeric(s) => {
|
|
if let Ok(i) = s.parse::<i64>() {
|
|
Ok(Value::Integer(i))
|
|
} else if let Ok(f) = s.parse::<f64>() {
|
|
Ok(Value::Float(f))
|
|
} else {
|
|
Ok(Value::Text(s.clone().into()))
|
|
}
|
|
}
|
|
ast::Literal::String(s) => {
|
|
// Strip surrounding quotes from the SQL literal
|
|
// The parser includes quotes in the string value
|
|
let unquoted = if s.starts_with('\'') && s.ends_with('\'') && s.len() > 1 {
|
|
&s[1..s.len() - 1]
|
|
} else {
|
|
s.as_str()
|
|
};
|
|
Ok(Value::Text(unquoted.to_string().into()))
|
|
}
|
|
ast::Literal::Blob(b) => Ok(Value::Blob(b.clone().into())),
|
|
ast::Literal::CurrentDate
|
|
| ast::Literal::CurrentTime
|
|
| ast::Literal::CurrentTimestamp => Err(LimboError::ParseError(
|
|
"Temporal literals not yet supported".to_string(),
|
|
)),
|
|
}
|
|
}
|
|
|
|
/// Parse aggregate function name (considering argument count for min/max)
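    ///
    /// For example, `MIN(x)` (one argument) resolves to `AggFunc::Min`, while
    /// `MIN(x, y)` returns `None` here and is handled as a scalar function.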
|
|
fn parse_aggregate_function(name: &str, arg_count: usize) -> Option<AggregateFunction> {
|
|
let name_bytes = name.as_bytes();
|
|
match_ignore_ascii_case!(match name_bytes {
|
|
b"COUNT" => Some(AggFunc::Count),
|
|
b"SUM" => Some(AggFunc::Sum),
|
|
b"AVG" => Some(AggFunc::Avg),
|
|
// MIN and MAX are only aggregates with 1 argument
|
|
// With 2+ arguments, they're scalar functions
|
|
b"MIN" if arg_count == 1 => Some(AggFunc::Min),
|
|
b"MAX" if arg_count == 1 => Some(AggFunc::Max),
|
|
b"GROUP_CONCAT" => Some(AggFunc::GroupConcat),
|
|
b"STRING_AGG" => Some(AggFunc::StringAgg),
|
|
b"TOTAL" => Some(AggFunc::Total),
|
|
_ => None,
|
|
})
|
|
}
|
|
|
|
// Check if expression contains aggregates
|
|
fn has_aggregates(columns: &[ast::ResultColumn]) -> bool {
|
|
for col in columns {
|
|
if let ast::ResultColumn::Expr(expr, _) = col {
|
|
if Self::expr_has_aggregate(expr) {
|
|
return true;
|
|
}
|
|
}
|
|
}
|
|
false
|
|
}
|
|
|
|
// Check if AST expression contains aggregates
|
|
fn expr_has_aggregate(expr: &ast::Expr) -> bool {
|
|
match expr {
|
|
ast::Expr::FunctionCall { name, args, .. } => {
|
|
// Check if the function itself is an aggregate (considering arg count for min/max)
|
|
let arg_count = args.len();
|
|
if Self::parse_aggregate_function(&Self::name_to_string(name), arg_count).is_some()
|
|
{
|
|
return true;
|
|
}
|
|
// Also check if any arguments contain aggregates (for nested functions like HEX(SUM(...)))
|
|
args.iter().any(|arg| Self::expr_has_aggregate(arg))
|
|
}
|
|
ast::Expr::FunctionCallStar { name, .. } => {
|
|
// FunctionCallStar always has 0 args
|
|
Self::parse_aggregate_function(&Self::name_to_string(name), 0).is_some()
|
|
}
|
|
ast::Expr::Binary(lhs, _, rhs) => {
|
|
Self::expr_has_aggregate(lhs) || Self::expr_has_aggregate(rhs)
|
|
}
|
|
ast::Expr::Unary(_, e) => Self::expr_has_aggregate(e),
|
|
ast::Expr::Case {
|
|
when_then_pairs,
|
|
else_expr,
|
|
..
|
|
} => {
|
|
when_then_pairs
|
|
.iter()
|
|
.any(|(w, t)| Self::expr_has_aggregate(w) || Self::expr_has_aggregate(t))
|
|
|| else_expr
|
|
.as_ref()
|
|
.is_some_and(|e| Self::expr_has_aggregate(e))
|
|
}
|
|
ast::Expr::Parenthesized(exprs) => {
|
|
// Check if any parenthesized expression contains an aggregate
|
|
exprs.iter().any(|e| Self::expr_has_aggregate(e))
|
|
}
|
|
_ => false,
|
|
}
|
|
}
|
|
|
|
// Check if logical expression is an aggregate
|
|
fn is_aggregate_expr(expr: &LogicalExpr) -> bool {
|
|
match expr {
|
|
LogicalExpr::AggregateFunction { .. } => true,
|
|
LogicalExpr::Alias { expr, .. } => Self::is_aggregate_expr(expr),
|
|
_ => false,
|
|
}
|
|
}
|
|
|
|
// Check if logical expression contains an aggregate anywhere
|
|
fn contains_aggregate(expr: &LogicalExpr) -> bool {
|
|
match expr {
|
|
LogicalExpr::AggregateFunction { .. } => true,
|
|
LogicalExpr::Alias { expr, .. } => Self::contains_aggregate(expr),
|
|
LogicalExpr::BinaryExpr { left, right, .. } => {
|
|
Self::contains_aggregate(left) || Self::contains_aggregate(right)
|
|
}
|
|
LogicalExpr::UnaryExpr { expr, .. } => Self::contains_aggregate(expr),
|
|
LogicalExpr::ScalarFunction { args, .. } => args.iter().any(Self::contains_aggregate),
|
|
LogicalExpr::Case {
|
|
when_then,
|
|
else_expr,
|
|
..
|
|
} => {
|
|
when_then
|
|
.iter()
|
|
.any(|(w, t)| Self::contains_aggregate(w) || Self::contains_aggregate(t))
|
|
|| else_expr
|
|
.as_ref()
|
|
.is_some_and(|e| Self::contains_aggregate(e))
|
|
}
|
|
_ => false,
|
|
}
|
|
}
|
|
|
|
    // Check if an expression is a constant (contains only literals)
    fn is_constant_expr(expr: &LogicalExpr) -> bool {
        match expr {
            LogicalExpr::Literal(_) => true,
            LogicalExpr::BinaryExpr { left, right, .. } => {
                Self::is_constant_expr(left) && Self::is_constant_expr(right)
            }
            LogicalExpr::UnaryExpr { expr, .. } => Self::is_constant_expr(expr),
            LogicalExpr::ScalarFunction { args, .. } => args.iter().all(Self::is_constant_expr),
            LogicalExpr::Alias { expr, .. } => Self::is_constant_expr(expr),
            _ => false,
        }
    }

    // Check if an expression is valid in GROUP BY context
    // An expression is valid if it's:
    // 1. A constant literal
    // 2. An aggregate function
    // 3. A grouping column (or expression involving only grouping columns)
    fn is_valid_in_group_by(expr: &LogicalExpr, group_exprs: &[LogicalExpr]) -> bool {
        match expr {
            LogicalExpr::Literal(_) => true, // Constants are always valid
            LogicalExpr::AggregateFunction { .. } => true, // Aggregates are valid
            LogicalExpr::Column(col) => {
                // Check if this column is in the GROUP BY
                group_exprs.iter().any(|g| match g {
                    LogicalExpr::Column(gcol) => gcol.name == col.name,
                    _ => false,
                })
            }
            LogicalExpr::BinaryExpr { left, right, .. } => {
                // Both sides must be valid
                Self::is_valid_in_group_by(left, group_exprs)
                    && Self::is_valid_in_group_by(right, group_exprs)
            }
            LogicalExpr::UnaryExpr { expr, .. } => Self::is_valid_in_group_by(expr, group_exprs),
            LogicalExpr::ScalarFunction { args, .. } => {
                // All arguments must be valid
                args.iter()
                    .all(|arg| Self::is_valid_in_group_by(arg, group_exprs))
            }
            LogicalExpr::Alias { expr, .. } => Self::is_valid_in_group_by(expr, group_exprs),
            _ => false, // Other expressions are not valid
        }
    }

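    // Illustrative note: with `GROUP BY user_id`, the rule above accepts
    // `user_id`, `42`, `(100 + 50) * 2`, and `COUNT(*)`, but returns false for a
    // bare reference to a non-grouped column such as `product`. See
    // `test_group_by_helper_predicates` in the tests module below for a sketch.
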
    // Extract aggregates from an expression and replace them with column references, with deduplication
    // Returns the modified expression and a list of NEW (aggregate_expr, column_name) pairs
    fn extract_and_replace_aggregates_with_dedup(
        expr: LogicalExpr,
        aggregate_map: &mut std::collections::HashMap<String, String>,
    ) -> Result<(LogicalExpr, Vec<(LogicalExpr, String)>)> {
        let mut new_aggregates = Vec::new();
        let mut counter = aggregate_map.len();
        let new_expr = Self::replace_aggregates_with_columns_dedup(
            expr,
            &mut new_aggregates,
            aggregate_map,
            &mut counter,
        )?;
        Ok((new_expr, new_aggregates))
    }

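    // Illustrative note: for `SELECT (COUNT(*) - 225), COUNT(*)`, both occurrences of
    // COUNT(*) hash to the same key, so the first call allocates `__agg_0` and records
    // one new aggregate, while the second reuses `__agg_0` and records none. See
    // `test_aggregate_dedup_helper` in the tests module below for a sketch of this.
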
    // Recursively replace aggregate functions with column references, with deduplication
    fn replace_aggregates_with_columns_dedup(
        expr: LogicalExpr,
        new_aggregates: &mut Vec<(LogicalExpr, String)>,
        aggregate_map: &mut std::collections::HashMap<String, String>,
        counter: &mut usize,
    ) -> Result<LogicalExpr> {
        match expr {
            LogicalExpr::AggregateFunction { .. } => {
                // Found an aggregate - check if we've seen it before
                let agg_key = format!("{expr:?}");

                let col_name = if let Some(existing_name) = aggregate_map.get(&agg_key) {
                    // Reuse existing aggregate
                    existing_name.clone()
                } else {
                    // New aggregate
                    let col_name = format!("__agg_{}", *counter);
                    *counter += 1;
                    aggregate_map.insert(agg_key, col_name.clone());
                    new_aggregates.push((expr, col_name.clone()));
                    col_name
                };

                Ok(LogicalExpr::Column(Column {
                    name: col_name,
                    table: None,
                }))
            }
            LogicalExpr::BinaryExpr { left, op, right } => {
                let new_left = Self::replace_aggregates_with_columns_dedup(
                    *left,
                    new_aggregates,
                    aggregate_map,
                    counter,
                )?;
                let new_right = Self::replace_aggregates_with_columns_dedup(
                    *right,
                    new_aggregates,
                    aggregate_map,
                    counter,
                )?;
                Ok(LogicalExpr::BinaryExpr {
                    left: Box::new(new_left),
                    op,
                    right: Box::new(new_right),
                })
            }
            LogicalExpr::UnaryExpr { op, expr } => {
                let new_expr = Self::replace_aggregates_with_columns_dedup(
                    *expr,
                    new_aggregates,
                    aggregate_map,
                    counter,
                )?;
                Ok(LogicalExpr::UnaryExpr {
                    op,
                    expr: Box::new(new_expr),
                })
            }
            LogicalExpr::ScalarFunction { fun, args } => {
                let mut new_args = Vec::new();
                for arg in args {
                    new_args.push(Self::replace_aggregates_with_columns_dedup(
                        arg,
                        new_aggregates,
                        aggregate_map,
                        counter,
                    )?);
                }
                Ok(LogicalExpr::ScalarFunction {
                    fun,
                    args: new_args,
                })
            }
            LogicalExpr::Case {
                expr: case_expr,
                when_then,
                else_expr,
            } => {
                let new_case_expr = if let Some(e) = case_expr {
                    Some(Box::new(Self::replace_aggregates_with_columns_dedup(
                        *e,
                        new_aggregates,
                        aggregate_map,
                        counter,
                    )?))
                } else {
                    None
                };

                let mut new_when_then = Vec::new();
                for (when, then) in when_then {
                    let new_when = Self::replace_aggregates_with_columns_dedup(
                        when,
                        new_aggregates,
                        aggregate_map,
                        counter,
                    )?;
                    let new_then = Self::replace_aggregates_with_columns_dedup(
                        then,
                        new_aggregates,
                        aggregate_map,
                        counter,
                    )?;
                    new_when_then.push((new_when, new_then));
                }

                let new_else = if let Some(e) = else_expr {
                    Some(Box::new(Self::replace_aggregates_with_columns_dedup(
                        *e,
                        new_aggregates,
                        aggregate_map,
                        counter,
                    )?))
                } else {
                    None
                };

                Ok(LogicalExpr::Case {
                    expr: new_case_expr,
                    when_then: new_when_then,
                    else_expr: new_else,
                })
            }
            LogicalExpr::Alias { expr, alias } => {
                let new_expr = Self::replace_aggregates_with_columns_dedup(
                    *expr,
                    new_aggregates,
                    aggregate_map,
                    counter,
                )?;
                Ok(LogicalExpr::Alias {
                    expr: Box::new(new_expr),
                    alias,
                })
            }
            // Other expressions - keep as is
            _ => Ok(expr),
        }
    }

    // Get column name from expression
    fn expr_to_column_name(expr: &ast::Expr) -> String {
        match expr {
            ast::Expr::Id(name) => Self::name_to_string(name),
            ast::Expr::Qualified(_, col) => Self::name_to_string(col),
            ast::Expr::FunctionCall { name, .. } => Self::name_to_string(name),
            ast::Expr::FunctionCallStar { name, .. } => {
                format!("{}(*)", Self::name_to_string(name))
            }
            _ => "expr".to_string(),
        }
    }

    // Get table schema
    fn get_table_schema(&self, table_name: &str, alias: Option<&str>) -> Result<SchemaRef> {
        // Look up table in schema
        let table = self
            .schema
            .get_table(table_name)
            .ok_or_else(|| LimboError::ParseError(format!("Table '{table_name}' not found")))?;

        // Parse table_name which might be "db.table" for attached databases
        let (database, actual_table) = if table_name.contains('.') {
            let parts: Vec<&str> = table_name.splitn(2, '.').collect();
            (Some(parts[0].to_string()), parts[1].to_string())
        } else {
            (None, table_name.to_string())
        };

        let mut columns = Vec::new();
        for col in table.columns() {
            if let Some(ref name) = col.name {
                columns.push(ColumnInfo {
                    name: name.clone(),
                    ty: col.ty,
                    database: database.clone(),
                    table: Some(actual_table.clone()),
                    table_alias: alias.map(|s| s.to_string()),
                });
            }
        }

        Ok(Arc::new(LogicalSchema::new(columns)))
    }

    // Infer expression type
    fn infer_expr_type(expr: &LogicalExpr, schema: &SchemaRef) -> Result<Type> {
        match expr {
            LogicalExpr::Column(col) => {
                if let Some((_, col_info)) = schema.find_column(&col.name, col.table.as_deref()) {
                    Ok(col_info.ty)
                } else {
                    Ok(Type::Text)
                }
            }
            LogicalExpr::Literal(Value::Integer(_)) => Ok(Type::Integer),
            LogicalExpr::Literal(Value::Float(_)) => Ok(Type::Real),
            LogicalExpr::Literal(Value::Text(_)) => Ok(Type::Text),
            LogicalExpr::Literal(Value::Null) => Ok(Type::Null),
            LogicalExpr::Literal(Value::Blob(_)) => Ok(Type::Blob),
            LogicalExpr::BinaryExpr { op, left, right } => {
                match op {
                    ast::Operator::Add | ast::Operator::Subtract | ast::Operator::Multiply => {
                        // Infer types of operands to match SQLite/Numeric behavior
                        let left_type = Self::infer_expr_type(left, schema)?;
                        let right_type = Self::infer_expr_type(right, schema)?;

                        // Integer op Integer = Integer (matching core/numeric/mod.rs behavior)
                        // Any operation with Real = Real
                        match (left_type, right_type) {
                            (Type::Integer, Type::Integer) => Ok(Type::Integer),
                            (Type::Integer, Type::Real)
                            | (Type::Real, Type::Integer)
                            | (Type::Real, Type::Real) => Ok(Type::Real),
                            (Type::Null, _) | (_, Type::Null) => Ok(Type::Null),
                            // For Text/Blob, SQLite coerces to numeric, defaulting to Real
                            _ => Ok(Type::Real),
                        }
                    }
                    ast::Operator::Divide => {
                        // Division always produces Real in SQLite
                        Ok(Type::Real)
                    }
                    ast::Operator::Modulus => {
                        // Modulus follows same rules as other arithmetic ops
                        let left_type = Self::infer_expr_type(left, schema)?;
                        let right_type = Self::infer_expr_type(right, schema)?;
                        match (left_type, right_type) {
                            (Type::Integer, Type::Integer) => Ok(Type::Integer),
                            _ => Ok(Type::Real),
                        }
                    }
                    ast::Operator::Equals
                    | ast::Operator::NotEquals
                    | ast::Operator::Less
                    | ast::Operator::LessEquals
                    | ast::Operator::Greater
                    | ast::Operator::GreaterEquals
                    | ast::Operator::And
                    | ast::Operator::Or
                    | ast::Operator::Is
                    | ast::Operator::IsNot => Ok(Type::Integer),
                    ast::Operator::Concat => Ok(Type::Text),
                    _ => Ok(Type::Text), // Default for other operators
                }
            }
            LogicalExpr::UnaryExpr { op, expr } => match op {
                ast::UnaryOperator::Not => Ok(Type::Integer),
                ast::UnaryOperator::Negative | ast::UnaryOperator::Positive => {
                    Self::infer_expr_type(expr, schema)
                }
                ast::UnaryOperator::BitwiseNot => Ok(Type::Integer),
            },
            LogicalExpr::AggregateFunction { fun, .. } => match fun {
                AggFunc::Count | AggFunc::Count0 => Ok(Type::Integer),
                AggFunc::Sum | AggFunc::Avg | AggFunc::Total => Ok(Type::Real),
                AggFunc::Min | AggFunc::Max => Ok(Type::Text),
                AggFunc::GroupConcat | AggFunc::StringAgg => Ok(Type::Text),
                #[cfg(feature = "json")]
                AggFunc::JsonbGroupArray
                | AggFunc::JsonGroupArray
                | AggFunc::JsonbGroupObject
                | AggFunc::JsonGroupObject => Ok(Type::Text),
                AggFunc::External(_) => Ok(Type::Text), // Default for external
            },
            LogicalExpr::Alias { expr, .. } => Self::infer_expr_type(expr, schema),
            LogicalExpr::IsNull { .. } => Ok(Type::Integer),
            LogicalExpr::InList { .. } | LogicalExpr::InSubquery { .. } => Ok(Type::Integer),
            LogicalExpr::Exists { .. } => Ok(Type::Integer),
            LogicalExpr::Between { .. } => Ok(Type::Integer),
            LogicalExpr::Like { .. } => Ok(Type::Integer),
            _ => Ok(Type::Text),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::schema::{BTreeTable, Column as SchemaColumn, Schema, Type};
    use turso_parser::parser::Parser;

fn create_test_schema() -> Schema {
|
|
let mut schema = Schema::new(false);
|
|
|
|
// Create users table
|
|
let users_table = BTreeTable {
|
|
name: "users".to_string(),
|
|
root_page: 2,
|
|
primary_key_columns: vec![("id".to_string(), turso_parser::ast::SortOrder::Asc)],
|
|
columns: vec![
|
|
SchemaColumn {
|
|
name: Some("id".to_string()),
|
|
ty: Type::Integer,
|
|
ty_str: "INTEGER".to_string(),
|
|
primary_key: true,
|
|
is_rowid_alias: true,
|
|
notnull: true,
|
|
default: None,
|
|
unique: false,
|
|
collation: None,
|
|
hidden: false,
|
|
},
|
|
SchemaColumn {
|
|
name: Some("name".to_string()),
|
|
ty: Type::Text,
|
|
ty_str: "TEXT".to_string(),
|
|
primary_key: false,
|
|
is_rowid_alias: false,
|
|
notnull: false,
|
|
default: None,
|
|
unique: false,
|
|
collation: None,
|
|
hidden: false,
|
|
},
|
|
SchemaColumn {
|
|
name: Some("age".to_string()),
|
|
ty: Type::Integer,
|
|
ty_str: "INTEGER".to_string(),
|
|
primary_key: false,
|
|
is_rowid_alias: false,
|
|
notnull: false,
|
|
default: None,
|
|
unique: false,
|
|
collation: None,
|
|
hidden: false,
|
|
},
|
|
SchemaColumn {
|
|
name: Some("email".to_string()),
|
|
ty: Type::Text,
|
|
ty_str: "TEXT".to_string(),
|
|
primary_key: false,
|
|
is_rowid_alias: false,
|
|
notnull: false,
|
|
default: None,
|
|
unique: false,
|
|
collation: None,
|
|
hidden: false,
|
|
},
|
|
],
|
|
has_rowid: true,
|
|
is_strict: false,
|
|
has_autoincrement: false,
|
|
unique_sets: vec![],
|
|
};
|
|
schema.add_btree_table(Arc::new(users_table));
|
|
|
|
// Create orders table
|
|
let orders_table = BTreeTable {
|
|
name: "orders".to_string(),
|
|
root_page: 3,
|
|
primary_key_columns: vec![("id".to_string(), turso_parser::ast::SortOrder::Asc)],
|
|
columns: vec![
|
|
SchemaColumn {
|
|
name: Some("id".to_string()),
|
|
ty: Type::Integer,
|
|
ty_str: "INTEGER".to_string(),
|
|
primary_key: true,
|
|
is_rowid_alias: true,
|
|
notnull: true,
|
|
default: None,
|
|
unique: false,
|
|
collation: None,
|
|
hidden: false,
|
|
},
|
|
SchemaColumn {
|
|
name: Some("user_id".to_string()),
|
|
ty: Type::Integer,
|
|
ty_str: "INTEGER".to_string(),
|
|
primary_key: false,
|
|
is_rowid_alias: false,
|
|
notnull: false,
|
|
default: None,
|
|
unique: false,
|
|
collation: None,
|
|
hidden: false,
|
|
},
|
|
SchemaColumn {
|
|
name: Some("product".to_string()),
|
|
ty: Type::Text,
|
|
ty_str: "TEXT".to_string(),
|
|
primary_key: false,
|
|
is_rowid_alias: false,
|
|
notnull: false,
|
|
default: None,
|
|
unique: false,
|
|
collation: None,
|
|
hidden: false,
|
|
},
|
|
SchemaColumn {
|
|
name: Some("amount".to_string()),
|
|
ty: Type::Real,
|
|
ty_str: "REAL".to_string(),
|
|
primary_key: false,
|
|
is_rowid_alias: false,
|
|
notnull: false,
|
|
default: None,
|
|
unique: false,
|
|
collation: None,
|
|
hidden: false,
|
|
},
|
|
],
|
|
has_rowid: true,
|
|
is_strict: false,
|
|
has_autoincrement: false,
|
|
unique_sets: vec![],
|
|
};
|
|
schema.add_btree_table(Arc::new(orders_table));
|
|
|
|
// Create products table
|
|
let products_table = BTreeTable {
|
|
name: "products".to_string(),
|
|
root_page: 4,
|
|
primary_key_columns: vec![("id".to_string(), turso_parser::ast::SortOrder::Asc)],
|
|
columns: vec![
|
|
SchemaColumn {
|
|
name: Some("id".to_string()),
|
|
ty: Type::Integer,
|
|
ty_str: "INTEGER".to_string(),
|
|
primary_key: true,
|
|
is_rowid_alias: true,
|
|
notnull: true,
|
|
default: None,
|
|
unique: false,
|
|
collation: None,
|
|
hidden: false,
|
|
},
|
|
SchemaColumn {
|
|
name: Some("name".to_string()),
|
|
ty: Type::Text,
|
|
ty_str: "TEXT".to_string(),
|
|
primary_key: false,
|
|
is_rowid_alias: false,
|
|
notnull: false,
|
|
default: None,
|
|
unique: false,
|
|
collation: None,
|
|
hidden: false,
|
|
},
|
|
SchemaColumn {
|
|
name: Some("price".to_string()),
|
|
ty: Type::Real,
|
|
ty_str: "REAL".to_string(),
|
|
primary_key: false,
|
|
is_rowid_alias: false,
|
|
notnull: false,
|
|
default: None,
|
|
unique: false,
|
|
collation: None,
|
|
hidden: false,
|
|
},
|
|
SchemaColumn {
|
|
name: Some("product_id".to_string()),
|
|
ty: Type::Integer,
|
|
ty_str: "INTEGER".to_string(),
|
|
primary_key: false,
|
|
is_rowid_alias: false,
|
|
notnull: false,
|
|
default: None,
|
|
unique: false,
|
|
collation: None,
|
|
hidden: false,
|
|
},
|
|
],
|
|
has_rowid: true,
|
|
is_strict: false,
|
|
has_autoincrement: false,
|
|
unique_sets: vec![],
|
|
};
|
|
schema.add_btree_table(Arc::new(products_table));
|
|
|
|
schema
|
|
}

    fn parse_and_build(sql: &str, schema: &Schema) -> Result<LogicalPlan> {
        let mut parser = Parser::new(sql.as_bytes());
        let cmd = parser
            .next()
            .ok_or_else(|| LimboError::ParseError("Empty statement".to_string()))?
            .map_err(|e| LimboError::ParseError(e.to_string()))?;
        match cmd {
            ast::Cmd::Stmt(stmt) => {
                let mut builder = LogicalPlanBuilder::new(schema);
                builder.build_statement(&stmt)
            }
            _ => Err(LimboError::ParseError(
                "Only SQL statements are supported".to_string(),
            )),
        }
    }

#[test]
|
|
fn test_simple_select() {
|
|
let schema = create_test_schema();
|
|
let sql = "SELECT id, name FROM users";
|
|
let plan = parse_and_build(sql, &schema).unwrap();
|
|
|
|
match plan {
|
|
LogicalPlan::Projection(proj) => {
|
|
assert_eq!(proj.exprs.len(), 2);
|
|
assert!(matches!(proj.exprs[0], LogicalExpr::Column(_)));
|
|
assert!(matches!(proj.exprs[1], LogicalExpr::Column(_)));
|
|
|
|
match &*proj.input {
|
|
LogicalPlan::TableScan(scan) => {
|
|
assert_eq!(scan.table_name, "users");
|
|
}
|
|
_ => panic!("Expected TableScan"),
|
|
}
|
|
}
|
|
_ => panic!("Expected Projection"),
|
|
}
|
|
}
|
|
|
|
#[test]
|
|
fn test_select_with_filter() {
|
|
let schema = create_test_schema();
|
|
let sql = "SELECT name FROM users WHERE age > 18";
|
|
let plan = parse_and_build(sql, &schema).unwrap();
|
|
|
|
match plan {
|
|
LogicalPlan::Projection(proj) => {
|
|
assert_eq!(proj.exprs.len(), 1);
|
|
|
|
match &*proj.input {
|
|
LogicalPlan::Filter(filter) => {
|
|
assert!(matches!(
|
|
filter.predicate,
|
|
LogicalExpr::BinaryExpr {
|
|
op: ast::Operator::Greater,
|
|
..
|
|
}
|
|
));
|
|
|
|
match &*filter.input {
|
|
LogicalPlan::TableScan(scan) => {
|
|
assert_eq!(scan.table_name, "users");
|
|
}
|
|
_ => panic!("Expected TableScan"),
|
|
}
|
|
}
|
|
_ => panic!("Expected Filter"),
|
|
}
|
|
}
|
|
_ => panic!("Expected Projection"),
|
|
}
|
|
}
|
|
|
|
#[test]
|
|
fn test_aggregate_with_group_by() {
|
|
let schema = create_test_schema();
|
|
let sql = "SELECT user_id, SUM(amount) FROM orders GROUP BY user_id";
|
|
let plan = parse_and_build(sql, &schema).unwrap();
|
|
|
|
match plan {
|
|
LogicalPlan::Aggregate(agg) => {
|
|
assert_eq!(agg.group_expr.len(), 1);
|
|
assert_eq!(agg.aggr_expr.len(), 1);
|
|
assert_eq!(agg.schema.column_count(), 2);
|
|
|
|
assert!(matches!(
|
|
agg.aggr_expr[0],
|
|
LogicalExpr::AggregateFunction {
|
|
fun: AggFunc::Sum,
|
|
..
|
|
}
|
|
));
|
|
|
|
match &*agg.input {
|
|
LogicalPlan::TableScan(scan) => {
|
|
assert_eq!(scan.table_name, "orders");
|
|
}
|
|
_ => panic!("Expected TableScan"),
|
|
}
|
|
}
|
|
_ => panic!("Expected Aggregate (no projection)"),
|
|
}
|
|
}
|
|
|
|
#[test]
|
|
fn test_aggregate_without_group_by() {
|
|
let schema = create_test_schema();
|
|
let sql = "SELECT COUNT(*), MAX(age) FROM users";
|
|
let plan = parse_and_build(sql, &schema).unwrap();
|
|
|
|
match plan {
|
|
LogicalPlan::Aggregate(agg) => {
|
|
assert_eq!(agg.group_expr.len(), 0);
|
|
assert_eq!(agg.aggr_expr.len(), 2);
|
|
assert_eq!(agg.schema.column_count(), 2);
|
|
|
|
assert!(matches!(
|
|
agg.aggr_expr[0],
|
|
LogicalExpr::AggregateFunction {
|
|
fun: AggFunc::Count,
|
|
..
|
|
}
|
|
));
|
|
|
|
assert!(matches!(
|
|
agg.aggr_expr[1],
|
|
LogicalExpr::AggregateFunction {
|
|
fun: AggFunc::Max,
|
|
..
|
|
}
|
|
));
|
|
}
|
|
_ => panic!("Expected Aggregate (no projection)"),
|
|
}
|
|
}
|
|
|
|
#[test]
|
|
fn test_order_by() {
|
|
let schema = create_test_schema();
|
|
let sql = "SELECT name FROM users ORDER BY age DESC, name ASC";
|
|
let plan = parse_and_build(sql, &schema).unwrap();
|
|
|
|
match plan {
|
|
LogicalPlan::Sort(sort) => {
|
|
assert_eq!(sort.exprs.len(), 2);
|
|
assert!(!sort.exprs[0].asc); // DESC
|
|
assert!(sort.exprs[1].asc); // ASC
|
|
|
|
match &*sort.input {
|
|
LogicalPlan::Projection(_) => {}
|
|
_ => panic!("Expected Projection"),
|
|
}
|
|
}
|
|
_ => panic!("Expected Sort"),
|
|
}
|
|
}
|
|
|
|
#[test]
|
|
fn test_limit_offset() {
|
|
let schema = create_test_schema();
|
|
let sql = "SELECT * FROM users LIMIT 10 OFFSET 5";
|
|
let plan = parse_and_build(sql, &schema).unwrap();
|
|
|
|
match plan {
|
|
LogicalPlan::Limit(limit) => {
|
|
assert_eq!(limit.fetch, Some(10));
|
|
assert_eq!(limit.skip, Some(5));
|
|
}
|
|
_ => panic!("Expected Limit"),
|
|
}
|
|
}
|
|
|
|
#[test]
|
|
fn test_order_by_with_limit() {
|
|
let schema = create_test_schema();
|
|
let sql = "SELECT name FROM users ORDER BY age DESC LIMIT 5";
|
|
let plan = parse_and_build(sql, &schema).unwrap();
|
|
|
|
// Should produce: Limit -> Sort -> Projection -> TableScan
|
|
match plan {
|
|
LogicalPlan::Limit(limit) => {
|
|
assert_eq!(limit.fetch, Some(5));
|
|
assert_eq!(limit.skip, None);
|
|
|
|
match &*limit.input {
|
|
LogicalPlan::Sort(sort) => {
|
|
assert_eq!(sort.exprs.len(), 1);
|
|
assert!(!sort.exprs[0].asc); // DESC
|
|
|
|
match &*sort.input {
|
|
LogicalPlan::Projection(_) => {}
|
|
_ => panic!("Expected Projection under Sort"),
|
|
}
|
|
}
|
|
_ => panic!("Expected Sort under Limit"),
|
|
}
|
|
}
|
|
_ => panic!("Expected Limit at top level"),
|
|
}
|
|
}
|
|
|
|
#[test]
|
|
fn test_distinct() {
|
|
let schema = create_test_schema();
|
|
let sql = "SELECT DISTINCT name FROM users";
|
|
let plan = parse_and_build(sql, &schema).unwrap();
|
|
|
|
match plan {
|
|
LogicalPlan::Distinct(distinct) => match &*distinct.input {
|
|
LogicalPlan::Projection(_) => {}
|
|
_ => panic!("Expected Projection"),
|
|
},
|
|
_ => panic!("Expected Distinct"),
|
|
}
|
|
}
|
|
|
|
#[test]
|
|
fn test_union() {
|
|
let schema = create_test_schema();
|
|
let sql = "SELECT id FROM users UNION SELECT user_id FROM orders";
|
|
let plan = parse_and_build(sql, &schema).unwrap();
|
|
|
|
match plan {
|
|
LogicalPlan::Union(union) => {
|
|
assert!(!union.all);
|
|
assert_eq!(union.inputs.len(), 2);
|
|
}
|
|
_ => panic!("Expected Union"),
|
|
}
|
|
}
|
|
|
|
#[test]
|
|
fn test_union_all() {
|
|
let schema = create_test_schema();
|
|
let sql = "SELECT id FROM users UNION ALL SELECT user_id FROM orders";
|
|
let plan = parse_and_build(sql, &schema).unwrap();
|
|
|
|
match plan {
|
|
LogicalPlan::Union(union) => {
|
|
assert!(union.all);
|
|
assert_eq!(union.inputs.len(), 2);
|
|
}
|
|
_ => panic!("Expected Union"),
|
|
}
|
|
}
|
|
|
|
#[test]
|
|
fn test_union_with_order_by() {
|
|
let schema = create_test_schema();
|
|
let sql = "SELECT id, name FROM users UNION SELECT user_id, name FROM orders ORDER BY id";
|
|
let plan = parse_and_build(sql, &schema).unwrap();
|
|
|
|
match plan {
|
|
LogicalPlan::Sort(sort) => {
|
|
assert_eq!(sort.exprs.len(), 1);
|
|
assert!(sort.exprs[0].asc); // Default ASC
|
|
|
|
match &*sort.input {
|
|
LogicalPlan::Union(union) => {
|
|
assert!(!union.all); // UNION (not UNION ALL)
|
|
assert_eq!(union.inputs.len(), 2);
|
|
}
|
|
_ => panic!("Expected Union under Sort"),
|
|
}
|
|
}
|
|
_ => panic!("Expected Sort at top level"),
|
|
}
|
|
}
|
|
|
|
#[test]
|
|
fn test_with_cte() {
|
|
let schema = create_test_schema();
|
|
let sql = "WITH active_users AS (SELECT * FROM users WHERE age > 18) SELECT name FROM active_users";
|
|
let plan = parse_and_build(sql, &schema).unwrap();
|
|
|
|
match plan {
|
|
LogicalPlan::WithCTE(with) => {
|
|
assert_eq!(with.ctes.len(), 1);
|
|
assert!(with.ctes.contains_key("active_users"));
|
|
|
|
let cte = &with.ctes["active_users"];
|
|
match &**cte {
|
|
LogicalPlan::Projection(proj) => match &*proj.input {
|
|
LogicalPlan::Filter(_) => {}
|
|
_ => panic!("Expected Filter in CTE"),
|
|
},
|
|
_ => panic!("Expected Projection in CTE"),
|
|
}
|
|
|
|
match &*with.body {
|
|
LogicalPlan::Projection(proj) => match &*proj.input {
|
|
LogicalPlan::CTERef(cte_ref) => {
|
|
assert_eq!(cte_ref.name, "active_users");
|
|
}
|
|
_ => panic!("Expected CTERef"),
|
|
},
|
|
_ => panic!("Expected Projection in body"),
|
|
}
|
|
}
|
|
_ => panic!("Expected WithCTE"),
|
|
}
|
|
}
|
|
|
|
#[test]
|
|
fn test_case_expression() {
|
|
let schema = create_test_schema();
|
|
let sql = "SELECT CASE WHEN age < 18 THEN 'minor' WHEN age < 65 THEN 'adult' ELSE 'senior' END FROM users";
|
|
let plan = parse_and_build(sql, &schema).unwrap();
|
|
|
|
match plan {
|
|
LogicalPlan::Projection(proj) => {
|
|
assert_eq!(proj.exprs.len(), 1);
|
|
assert!(matches!(proj.exprs[0], LogicalExpr::Case { .. }));
|
|
}
|
|
_ => panic!("Expected Projection"),
|
|
}
|
|
}
|
|
|
|
#[test]
|
|
fn test_in_list() {
|
|
let schema = create_test_schema();
|
|
let sql = "SELECT * FROM users WHERE id IN (1, 2, 3)";
|
|
let plan = parse_and_build(sql, &schema).unwrap();
|
|
|
|
match plan {
|
|
LogicalPlan::Projection(proj) => match &*proj.input {
|
|
LogicalPlan::Filter(filter) => match &filter.predicate {
|
|
LogicalExpr::InList { list, negated, .. } => {
|
|
assert!(!negated);
|
|
assert_eq!(list.len(), 3);
|
|
}
|
|
_ => panic!("Expected InList"),
|
|
},
|
|
_ => panic!("Expected Filter"),
|
|
},
|
|
_ => panic!("Expected Projection"),
|
|
}
|
|
}
|
|
|
|
#[test]
|
|
fn test_in_subquery() {
|
|
let schema = create_test_schema();
|
|
let sql = "SELECT * FROM users WHERE id IN (SELECT user_id FROM orders)";
|
|
let plan = parse_and_build(sql, &schema).unwrap();
|
|
|
|
match plan {
|
|
LogicalPlan::Projection(proj) => match &*proj.input {
|
|
LogicalPlan::Filter(filter) => {
|
|
assert!(matches!(filter.predicate, LogicalExpr::InSubquery { .. }));
|
|
}
|
|
_ => panic!("Expected Filter"),
|
|
},
|
|
_ => panic!("Expected Projection"),
|
|
}
|
|
}
|
|
|
|
#[test]
|
|
fn test_exists_subquery() {
|
|
let schema = create_test_schema();
|
|
let sql = "SELECT * FROM users WHERE EXISTS (SELECT 1 FROM orders WHERE orders.user_id = users.id)";
|
|
let plan = parse_and_build(sql, &schema).unwrap();
|
|
|
|
match plan {
|
|
LogicalPlan::Projection(proj) => match &*proj.input {
|
|
LogicalPlan::Filter(filter) => {
|
|
assert!(matches!(filter.predicate, LogicalExpr::Exists { .. }));
|
|
}
|
|
_ => panic!("Expected Filter"),
|
|
},
|
|
_ => panic!("Expected Projection"),
|
|
}
|
|
}
|
|
|
|
#[test]
|
|
fn test_between() {
|
|
let schema = create_test_schema();
|
|
let sql = "SELECT * FROM users WHERE age BETWEEN 18 AND 65";
|
|
let plan = parse_and_build(sql, &schema).unwrap();
|
|
|
|
match plan {
|
|
LogicalPlan::Projection(proj) => match &*proj.input {
|
|
LogicalPlan::Filter(filter) => match &filter.predicate {
|
|
LogicalExpr::Between { negated, .. } => {
|
|
assert!(!negated);
|
|
}
|
|
_ => panic!("Expected Between"),
|
|
},
|
|
_ => panic!("Expected Filter"),
|
|
},
|
|
_ => panic!("Expected Projection"),
|
|
}
|
|
}
|
|
|
|
#[test]
|
|
fn test_like() {
|
|
let schema = create_test_schema();
|
|
let sql = "SELECT * FROM users WHERE name LIKE 'John%'";
|
|
let plan = parse_and_build(sql, &schema).unwrap();
|
|
|
|
match plan {
|
|
LogicalPlan::Projection(proj) => match &*proj.input {
|
|
LogicalPlan::Filter(filter) => match &filter.predicate {
|
|
LogicalExpr::Like {
|
|
negated, escape, ..
|
|
} => {
|
|
assert!(!negated);
|
|
assert!(escape.is_none());
|
|
}
|
|
_ => panic!("Expected Like"),
|
|
},
|
|
_ => panic!("Expected Filter"),
|
|
},
|
|
_ => panic!("Expected Projection"),
|
|
}
|
|
}
|
|
|
|
#[test]
|
|
fn test_is_null() {
|
|
let schema = create_test_schema();
|
|
let sql = "SELECT * FROM users WHERE email IS NULL";
|
|
let plan = parse_and_build(sql, &schema).unwrap();
|
|
|
|
match plan {
|
|
LogicalPlan::Projection(proj) => match &*proj.input {
|
|
LogicalPlan::Filter(filter) => match &filter.predicate {
|
|
LogicalExpr::IsNull { negated, .. } => {
|
|
assert!(!negated);
|
|
}
|
|
_ => panic!("Expected IsNull, got: {:?}", filter.predicate),
|
|
},
|
|
_ => panic!("Expected Filter"),
|
|
},
|
|
_ => panic!("Expected Projection"),
|
|
}
|
|
}
|
|
|
|
#[test]
|
|
fn test_is_not_null() {
|
|
let schema = create_test_schema();
|
|
let sql = "SELECT * FROM users WHERE email IS NOT NULL";
|
|
let plan = parse_and_build(sql, &schema).unwrap();
|
|
|
|
match plan {
|
|
LogicalPlan::Projection(proj) => match &*proj.input {
|
|
LogicalPlan::Filter(filter) => match &filter.predicate {
|
|
LogicalExpr::IsNull { negated, .. } => {
|
|
assert!(negated);
|
|
}
|
|
_ => panic!("Expected IsNull"),
|
|
},
|
|
_ => panic!("Expected Filter"),
|
|
},
|
|
_ => panic!("Expected Projection"),
|
|
}
|
|
}
|
|
|
|
#[test]
|
|
fn test_values_clause() {
|
|
let schema = create_test_schema();
|
|
let sql = "SELECT * FROM (VALUES (1, 'a'), (2, 'b'))";
|
|
let plan = parse_and_build(sql, &schema).unwrap();
|
|
|
|
match plan {
|
|
LogicalPlan::Projection(proj) => match &*proj.input {
|
|
LogicalPlan::Values(values) => {
|
|
assert_eq!(values.rows.len(), 2);
|
|
assert_eq!(values.rows[0].len(), 2);
|
|
}
|
|
_ => panic!("Expected Values"),
|
|
},
|
|
_ => panic!("Expected Projection"),
|
|
}
|
|
}
|
|
|
|
#[test]
|
|
fn test_complex_expression_with_aggregation() {
|
|
// Test: SELECT sum(id + 2) * 2 FROM orders GROUP BY user_id
|
|
let schema = create_test_schema();
|
|
|
|
// Test the complex case: sum((id + 2)) * 2 with parentheses
|
|
let sql = "SELECT sum((id + 2)) * 2 FROM orders GROUP BY user_id";
|
|
let plan = parse_and_build(sql, &schema).unwrap();
|
|
|
|
match plan {
|
|
LogicalPlan::Projection(proj) => {
|
|
assert_eq!(proj.exprs.len(), 1);
|
|
match &proj.exprs[0] {
|
|
LogicalExpr::BinaryExpr { left, op, right } => {
|
|
assert_eq!(*op, BinaryOperator::Multiply);
|
|
assert!(matches!(**left, LogicalExpr::Column(_)));
|
|
assert!(matches!(**right, LogicalExpr::Literal(_)));
|
|
}
|
|
_ => panic!("Expected BinaryExpr in projection"),
|
|
}
|
|
|
|
match &*proj.input {
|
|
LogicalPlan::Aggregate(agg) => {
|
|
assert_eq!(agg.group_expr.len(), 1);
|
|
|
|
assert_eq!(agg.aggr_expr.len(), 1);
|
|
match &agg.aggr_expr[0] {
|
|
LogicalExpr::AggregateFunction { fun, args, .. } => {
|
|
assert_eq!(*fun, AggregateFunction::Sum);
|
|
assert_eq!(args.len(), 1);
|
|
match &args[0] {
|
|
LogicalExpr::Column(col) => {
|
|
assert!(col.name.starts_with("__agg_arg_proj_"));
|
|
}
|
|
_ => panic!("Expected Column reference to projected expression in aggregate args, got {:?}", args[0]),
|
|
}
|
|
}
|
|
_ => panic!("Expected AggregateFunction"),
|
|
}
|
|
|
|
match &*agg.input {
|
|
LogicalPlan::Projection(inner_proj) => {
|
|
assert!(inner_proj.exprs.len() >= 2);
|
|
let has_binary_add = inner_proj.exprs.iter().any(|e| {
|
|
matches!(
|
|
e,
|
|
LogicalExpr::BinaryExpr {
|
|
op: BinaryOperator::Add,
|
|
..
|
|
}
|
|
)
|
|
});
|
|
assert!(
|
|
has_binary_add,
|
|
"Should have id + 2 expression in inner projection"
|
|
);
|
|
}
|
|
_ => panic!("Expected Projection as input to Aggregate"),
|
|
}
|
|
}
|
|
_ => panic!("Expected Aggregate under Projection"),
|
|
}
|
|
}
|
|
_ => panic!("Expected Projection at top level"),
|
|
}
|
|
}
|
|
|
|
#[test]
|
|
fn test_function_on_aggregate_result() {
|
|
let schema = create_test_schema();
|
|
|
|
let sql = "SELECT abs(sum(id)) FROM orders GROUP BY user_id";
|
|
let plan = parse_and_build(sql, &schema).unwrap();
|
|
match plan {
|
|
LogicalPlan::Projection(proj) => {
|
|
assert_eq!(proj.exprs.len(), 1);
|
|
match &proj.exprs[0] {
|
|
LogicalExpr::ScalarFunction { fun, args } => {
|
|
assert_eq!(fun, "abs");
|
|
assert_eq!(args.len(), 1);
|
|
assert!(matches!(args[0], LogicalExpr::Column(_)));
|
|
}
|
|
_ => panic!("Expected ScalarFunction in projection"),
|
|
}
|
|
}
|
|
_ => panic!("Expected Projection"),
|
|
}
|
|
}
|
|
|
|
#[test]
|
|
fn test_multiple_aggregates_with_arithmetic() {
|
|
let schema = create_test_schema();
|
|
|
|
let sql = "SELECT sum(id) * 2 + count(*) FROM orders GROUP BY user_id";
|
|
let plan = parse_and_build(sql, &schema).unwrap();
|
|
match plan {
|
|
LogicalPlan::Projection(proj) => {
|
|
assert_eq!(proj.exprs.len(), 1);
|
|
match &proj.exprs[0] {
|
|
LogicalExpr::BinaryExpr { op, .. } => {
|
|
assert_eq!(*op, BinaryOperator::Add);
|
|
}
|
|
_ => panic!("Expected BinaryExpr"),
|
|
}
|
|
|
|
match &*proj.input {
|
|
LogicalPlan::Aggregate(agg) => {
|
|
assert_eq!(agg.aggr_expr.len(), 2);
|
|
}
|
|
_ => panic!("Expected Aggregate"),
|
|
}
|
|
}
|
|
_ => panic!("Expected Projection"),
|
|
}
|
|
}
|
|
|
|
#[test]
|
|
fn test_projection_aggregation_projection() {
|
|
let schema = create_test_schema();
|
|
|
|
// This tests: projection -> aggregation -> projection
|
|
// The inner projection computes (id + 2), then we aggregate sum(), then apply abs()
|
|
let sql = "SELECT abs(sum(id + 2)) FROM orders GROUP BY user_id";
|
|
let plan = parse_and_build(sql, &schema).unwrap();
|
|
|
|
// Should produce: Projection(abs) -> Aggregate(sum) -> Projection(id + 2) -> TableScan
|
|
match plan {
|
|
LogicalPlan::Projection(outer_proj) => {
|
|
assert_eq!(outer_proj.exprs.len(), 1);
|
|
|
|
// Outer projection should apply abs() function
|
|
match &outer_proj.exprs[0] {
|
|
LogicalExpr::ScalarFunction { fun, args } => {
|
|
assert_eq!(fun, "abs");
|
|
assert_eq!(args.len(), 1);
|
|
assert!(matches!(args[0], LogicalExpr::Column(_)));
|
|
}
|
|
_ => panic!("Expected abs() function in outer projection"),
|
|
}
|
|
|
|
// Next should be the Aggregate
|
|
match &*outer_proj.input {
|
|
LogicalPlan::Aggregate(agg) => {
|
|
assert_eq!(agg.group_expr.len(), 1);
|
|
assert_eq!(agg.aggr_expr.len(), 1);
|
|
|
|
// The aggregate should be summing a column reference
|
|
match &agg.aggr_expr[0] {
|
|
LogicalExpr::AggregateFunction { fun, args, .. } => {
|
|
assert_eq!(*fun, AggregateFunction::Sum);
|
|
assert_eq!(args.len(), 1);
|
|
|
|
// Should reference the projected column
|
|
match &args[0] {
|
|
LogicalExpr::Column(col) => {
|
|
assert!(col.name.starts_with("__agg_arg_proj_"));
|
|
}
|
|
_ => panic!("Expected column reference in aggregate"),
|
|
}
|
|
}
|
|
_ => panic!("Expected AggregateFunction"),
|
|
}
|
|
|
|
// Input to aggregate should be a projection computing id + 2
|
|
match &*agg.input {
|
|
LogicalPlan::Projection(inner_proj) => {
|
|
// Should have at least the group column and the computed expression
|
|
assert!(inner_proj.exprs.len() >= 2);
|
|
|
|
// Check for the id + 2 expression
|
|
let has_add_expr = inner_proj.exprs.iter().any(|e| {
|
|
matches!(
|
|
e,
|
|
LogicalExpr::BinaryExpr {
|
|
op: BinaryOperator::Add,
|
|
..
|
|
}
|
|
)
|
|
});
|
|
assert!(
|
|
has_add_expr,
|
|
"Should have id + 2 expression in inner projection"
|
|
);
|
|
}
|
|
_ => panic!("Expected inner Projection under Aggregate"),
|
|
}
|
|
}
|
|
_ => panic!("Expected Aggregate under outer Projection"),
|
|
}
|
|
}
|
|
_ => panic!("Expected Projection at top level"),
|
|
}
|
|
}
|
|
|
|
#[test]
|
|
fn test_group_by_validation_allow_grouped_column() {
|
|
let schema = create_test_schema();
|
|
|
|
// Test that grouped columns are allowed
|
|
let sql = "SELECT user_id, COUNT(*) FROM orders GROUP BY user_id";
|
|
let result = parse_and_build(sql, &schema);
|
|
|
|
assert!(result.is_ok(), "Should allow grouped column in SELECT");
|
|
}
|
|
|
|
#[test]
|
|
fn test_group_by_validation_allow_constants() {
|
|
let schema = create_test_schema();
|
|
|
|
// Test that simple constants are allowed even when not grouped
|
|
let sql = "SELECT user_id, 42, COUNT(*) FROM orders GROUP BY user_id";
|
|
let result = parse_and_build(sql, &schema);
|
|
|
|
assert!(
|
|
result.is_ok(),
|
|
"Should allow simple constants in SELECT with GROUP BY"
|
|
);
|
|
|
|
let sql_complex = "SELECT user_id, (100 + 50) * 2, COUNT(*) FROM orders GROUP BY user_id";
|
|
let result_complex = parse_and_build(sql_complex, &schema);
|
|
|
|
assert!(
|
|
result_complex.is_ok(),
|
|
"Should allow complex constant expressions in SELECT with GROUP BY"
|
|
);
|
|
}
|
|
|
|
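    // A minimal sketch of the GROUP BY / constant-expression helpers used by the builder.
    // It assumes `is_constant_expr` and `is_valid_in_group_by` remain private associated
    // functions on `LogicalPlanBuilder`, reachable from this child module.
    #[test]
    fn test_group_by_helper_predicates() {
        // (100 + 50) * 2 is a constant expression.
        let constant = LogicalExpr::BinaryExpr {
            left: Box::new(LogicalExpr::BinaryExpr {
                left: Box::new(LogicalExpr::Literal(Value::Integer(100))),
                op: ast::Operator::Add,
                right: Box::new(LogicalExpr::Literal(Value::Integer(50))),
            }),
            op: ast::Operator::Multiply,
            right: Box::new(LogicalExpr::Literal(Value::Integer(2))),
        };
        assert!(LogicalPlanBuilder::is_constant_expr(&constant));

        // A bare column is not constant, and is only valid in GROUP BY context
        // when it appears among the grouping expressions.
        let user_id = LogicalExpr::Column(Column {
            name: "user_id".to_string(),
            table: None,
        });
        assert!(!LogicalPlanBuilder::is_constant_expr(&user_id));

        let group_exprs = vec![LogicalExpr::Column(Column {
            name: "user_id".to_string(),
            table: None,
        })];
        assert!(LogicalPlanBuilder::is_valid_in_group_by(&user_id, &group_exprs));
        assert!(LogicalPlanBuilder::is_valid_in_group_by(&constant, &group_exprs));

        let product = LogicalExpr::Column(Column {
            name: "product".to_string(),
            table: None,
        });
        assert!(!LogicalPlanBuilder::is_valid_in_group_by(&product, &group_exprs));
    }
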
#[test]
|
|
fn test_parenthesized_aggregate_expressions() {
|
|
let schema = create_test_schema();
|
|
|
|
let sql = "SELECT 25, (MAX(id) / 3), 39 FROM orders";
|
|
let result = parse_and_build(sql, &schema);
|
|
|
|
assert!(
|
|
result.is_ok(),
|
|
"Should handle parenthesized aggregate expressions"
|
|
);
|
|
|
|
let plan = result.unwrap();
|
|
|
|
match plan {
|
|
LogicalPlan::Projection(proj) => {
|
|
assert_eq!(proj.exprs.len(), 3);
|
|
|
|
assert!(matches!(
|
|
proj.exprs[0],
|
|
LogicalExpr::Literal(Value::Integer(25))
|
|
));
|
|
|
|
match &proj.exprs[1] {
|
|
LogicalExpr::BinaryExpr { left, op, right } => {
|
|
assert_eq!(*op, BinaryOperator::Divide);
|
|
assert!(matches!(&**left, LogicalExpr::Column(_)));
|
|
assert!(matches!(&**right, LogicalExpr::Literal(Value::Integer(3))));
|
|
}
|
|
_ => panic!("Expected BinaryExpr for (MAX(id) / 3)"),
|
|
}
|
|
|
|
assert!(matches!(
|
|
proj.exprs[2],
|
|
LogicalExpr::Literal(Value::Integer(39))
|
|
));
|
|
|
|
match &*proj.input {
|
|
LogicalPlan::Aggregate(agg) => {
|
|
assert_eq!(agg.aggr_expr.len(), 1);
|
|
assert!(matches!(
|
|
agg.aggr_expr[0],
|
|
LogicalExpr::AggregateFunction {
|
|
fun: AggFunc::Max,
|
|
..
|
|
}
|
|
));
|
|
}
|
|
_ => panic!("Expected Aggregate node under Projection"),
|
|
}
|
|
}
|
|
_ => panic!("Expected Projection at top level"),
|
|
}
|
|
}
|
|
|
|
#[test]
|
|
fn test_duplicate_aggregate_reuse() {
|
|
let schema = create_test_schema();
|
|
|
|
let sql = "SELECT (COUNT(*) - 225), 30, COUNT(*) FROM orders";
|
|
let result = parse_and_build(sql, &schema);
|
|
|
|
assert!(result.is_ok(), "Should handle duplicate aggregates");
|
|
|
|
let plan = result.unwrap();
|
|
|
|
match plan {
|
|
LogicalPlan::Projection(proj) => {
|
|
assert_eq!(proj.exprs.len(), 3);
|
|
|
|
match &proj.exprs[0] {
|
|
LogicalExpr::BinaryExpr { left, op, right } => {
|
|
assert_eq!(*op, BinaryOperator::Subtract);
|
|
match &**left {
|
|
LogicalExpr::Column(col) => {
|
|
assert!(col.name.starts_with("__agg_") || col.name == "COUNT(*)");
|
|
}
|
|
_ => panic!("Expected Column reference for COUNT(*)"),
|
|
}
|
|
assert!(matches!(
|
|
&**right,
|
|
LogicalExpr::Literal(Value::Integer(225))
|
|
));
|
|
}
|
|
_ => panic!("Expected BinaryExpr for (COUNT(*) - 225)"),
|
|
}
|
|
|
|
assert!(matches!(
|
|
proj.exprs[1],
|
|
LogicalExpr::Literal(Value::Integer(30))
|
|
));
|
|
|
|
match &proj.exprs[2] {
|
|
LogicalExpr::Column(col) => {
|
|
assert!(col.name.starts_with("__agg_") || col.name == "COUNT(*)");
|
|
}
|
|
_ => panic!("Expected Column reference for COUNT(*)"),
|
|
}
|
|
|
|
match &*proj.input {
|
|
LogicalPlan::Aggregate(agg) => {
|
|
assert_eq!(
|
|
agg.aggr_expr.len(),
|
|
1,
|
|
"Should have only one COUNT(*) aggregate"
|
|
);
|
|
assert!(matches!(
|
|
agg.aggr_expr[0],
|
|
LogicalExpr::AggregateFunction {
|
|
fun: AggFunc::Count,
|
|
..
|
|
}
|
|
));
|
|
}
|
|
_ => panic!("Expected Aggregate node under Projection"),
|
|
}
|
|
}
|
|
_ => panic!("Expected Projection at top level"),
|
|
}
|
|
}
|
|
|
|
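    // A minimal sketch of the aggregate-deduplication helper itself, assuming
    // `extract_and_replace_aggregates_with_dedup` stays an associated function on
    // `LogicalPlanBuilder` with the `__agg_{n}` naming scheme used by the builder.
    #[test]
    fn test_aggregate_dedup_helper() {
        let mut aggregate_map = HashMap::new();

        let make_count = || LogicalExpr::AggregateFunction {
            fun: AggFunc::Count,
            args: vec![],
            distinct: false,
        };

        // First occurrence allocates a synthetic column and records one new aggregate.
        let (expr1, new1) = LogicalPlanBuilder::extract_and_replace_aggregates_with_dedup(
            make_count(),
            &mut aggregate_map,
        )
        .unwrap();
        assert_eq!(new1.len(), 1);
        match expr1 {
            LogicalExpr::Column(col) => assert_eq!(col.name, "__agg_0"),
            other => panic!("Expected column reference, got {other:?}"),
        }

        // A structurally identical aggregate reuses the same column and adds nothing new.
        let (expr2, new2) = LogicalPlanBuilder::extract_and_replace_aggregates_with_dedup(
            make_count(),
            &mut aggregate_map,
        )
        .unwrap();
        assert!(new2.is_empty());
        match expr2 {
            LogicalExpr::Column(col) => assert_eq!(col.name, "__agg_0"),
            other => panic!("Expected column reference, got {other:?}"),
        }
    }
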
#[test]
|
|
fn test_aggregate_without_group_by_allow_constants() {
|
|
let schema = create_test_schema();
|
|
|
|
// Test that constants are allowed with aggregates even without GROUP BY
|
|
let sql = "SELECT 42, COUNT(*), MAX(amount) FROM orders";
|
|
let result = parse_and_build(sql, &schema);
|
|
|
|
assert!(
|
|
result.is_ok(),
|
|
"Should allow simple constants with aggregates without GROUP BY"
|
|
);
|
|
|
|
// Test complex constant expressions
|
|
let sql_complex = "SELECT (9 / 6) % 5, COUNT(*), MAX(amount) FROM orders";
|
|
let result_complex = parse_and_build(sql_complex, &schema);
|
|
|
|
assert!(
|
|
result_complex.is_ok(),
|
|
"Should allow complex constant expressions with aggregates without GROUP BY"
|
|
);
|
|
}
|
|
|
|
#[test]
|
|
fn test_aggregate_without_group_by_creates_aggregate_node() {
|
|
let schema = create_test_schema();
|
|
|
|
// Test that aggregate without GROUP BY creates proper Aggregate node
|
|
let sql = "SELECT MAX(amount) FROM orders";
|
|
let plan = parse_and_build(sql, &schema).unwrap();
|
|
|
|
// Should be: Aggregate -> TableScan (no projection needed for simple aggregate)
|
|
match plan {
|
|
LogicalPlan::Aggregate(agg) => {
|
|
assert_eq!(agg.group_expr.len(), 0, "Should have no group expressions");
|
|
assert_eq!(
|
|
agg.aggr_expr.len(),
|
|
1,
|
|
"Should have one aggregate expression"
|
|
);
|
|
assert_eq!(
|
|
agg.schema.column_count(),
|
|
1,
|
|
"Schema should have one column"
|
|
);
|
|
}
|
|
_ => panic!("Expected Aggregate at top level (no projection)"),
|
|
}
|
|
}
|
|
|
|
#[test]
|
|
fn test_scalar_vs_aggregate_function_classification() {
|
|
let schema = create_test_schema();
|
|
|
|
// Test MIN/MAX with 1 argument - should be aggregate
|
|
let sql = "SELECT MIN(amount) FROM orders";
|
|
let plan = parse_and_build(sql, &schema).unwrap();
|
|
match plan {
|
|
LogicalPlan::Aggregate(agg) => {
|
|
assert_eq!(agg.aggr_expr.len(), 1, "MIN(x) should be an aggregate");
|
|
match &agg.aggr_expr[0] {
|
|
LogicalExpr::AggregateFunction { fun, args, .. } => {
|
|
assert!(matches!(fun, AggFunc::Min));
|
|
assert_eq!(args.len(), 1);
|
|
}
|
|
_ => panic!("Expected AggregateFunction"),
|
|
}
|
|
}
|
|
_ => panic!("Expected Aggregate node for MIN(x)"),
|
|
}
|
|
|
|
// Test MIN/MAX with 2 arguments - should be scalar in projection
|
|
let sql = "SELECT MIN(amount, user_id) FROM orders";
|
|
let plan = parse_and_build(sql, &schema).unwrap();
|
|
match plan {
|
|
LogicalPlan::Projection(proj) => {
|
|
assert_eq!(proj.exprs.len(), 1, "Should have one projection expression");
|
|
match &proj.exprs[0] {
|
|
LogicalExpr::ScalarFunction { fun, args } => {
|
|
assert_eq!(
|
|
fun.to_lowercase(),
|
|
"min",
|
|
"MIN(x,y) should be a scalar function"
|
|
);
|
|
assert_eq!(args.len(), 2);
|
|
}
|
|
_ => panic!("Expected ScalarFunction for MIN(x,y)"),
|
|
}
|
|
}
|
|
_ => panic!("Expected Projection node for scalar MIN(x,y)"),
|
|
}
|
|
|
|
// Test MAX with 3 arguments - should be scalar
|
|
let sql = "SELECT MAX(amount, user_id, id) FROM orders";
|
|
let plan = parse_and_build(sql, &schema).unwrap();
|
|
match plan {
|
|
LogicalPlan::Projection(proj) => {
|
|
assert_eq!(proj.exprs.len(), 1);
|
|
match &proj.exprs[0] {
|
|
LogicalExpr::ScalarFunction { fun, args } => {
|
|
assert_eq!(
|
|
fun.to_lowercase(),
|
|
"max",
|
|
"MAX(x,y,z) should be a scalar function"
|
|
);
|
|
assert_eq!(args.len(), 3);
|
|
}
|
|
_ => panic!("Expected ScalarFunction for MAX(x,y,z)"),
|
|
}
|
|
}
|
|
_ => panic!("Expected Projection node for scalar MAX(x,y,z)"),
|
|
}
|
|
|
|
// Test that MIN with 0 args is treated as scalar (will fail later in execution)
|
|
let sql = "SELECT MIN() FROM orders";
|
|
let plan = parse_and_build(sql, &schema).unwrap();
|
|
match plan {
|
|
LogicalPlan::Projection(proj) => match &proj.exprs[0] {
|
|
LogicalExpr::ScalarFunction { fun, args } => {
|
|
assert_eq!(fun.to_lowercase(), "min");
|
|
assert_eq!(args.len(), 0, "MIN() should be scalar with 0 args");
|
|
}
|
|
_ => panic!("Expected ScalarFunction for MIN()"),
|
|
},
|
|
_ => panic!("Expected Projection for MIN()"),
|
|
}
|
|
|
|
// Test other functions that are always aggregate (COUNT, SUM, AVG)
|
|
let sql = "SELECT COUNT(*), SUM(amount), AVG(amount) FROM orders";
|
|
let plan = parse_and_build(sql, &schema).unwrap();
|
|
match plan {
|
|
LogicalPlan::Aggregate(agg) => {
|
|
assert_eq!(agg.aggr_expr.len(), 3, "Should have 3 aggregate functions");
|
|
for expr in &agg.aggr_expr {
|
|
assert!(matches!(expr, LogicalExpr::AggregateFunction { .. }));
|
|
}
|
|
}
|
|
_ => panic!("Expected Aggregate node"),
|
|
}
|
|
|
|
// Test scalar functions that are never aggregates (ABS, ROUND, etc.)
|
|
let sql = "SELECT ABS(amount), ROUND(amount), LENGTH(product) FROM orders";
|
|
let plan = parse_and_build(sql, &schema).unwrap();
|
|
match plan {
|
|
LogicalPlan::Projection(proj) => {
|
|
assert_eq!(proj.exprs.len(), 3, "Should have 3 scalar functions");
|
|
for expr in &proj.exprs {
|
|
match expr {
|
|
LogicalExpr::ScalarFunction { .. } => {}
|
|
_ => panic!("Expected all ScalarFunctions"),
|
|
}
|
|
}
|
|
}
|
|
_ => panic!("Expected Projection node for scalar functions"),
|
|
}
|
|
}
|
|
|
|
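    // A small sketch of the arity rule in `parse_aggregate_function`, assuming it stays
    // an associated function on `LogicalPlanBuilder`; name matching is ASCII case-insensitive.
    #[test]
    fn test_parse_aggregate_function_arity() {
        // MIN/MAX are aggregates only with exactly one argument.
        assert!(LogicalPlanBuilder::parse_aggregate_function("min", 1).is_some());
        assert!(LogicalPlanBuilder::parse_aggregate_function("MIN", 2).is_none());
        assert!(LogicalPlanBuilder::parse_aggregate_function("MAX", 3).is_none());
        // COUNT/SUM/AVG are recognized as aggregates regardless of argument count.
        assert!(LogicalPlanBuilder::parse_aggregate_function("count", 0).is_some());
        assert!(LogicalPlanBuilder::parse_aggregate_function("SUM", 1).is_some());
        // Ordinary scalar functions are never recognized as aggregates here.
        assert!(LogicalPlanBuilder::parse_aggregate_function("abs", 1).is_none());
    }
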
#[test]
|
|
fn test_mixed_aggregate_and_group_columns() {
|
|
let schema = create_test_schema();
|
|
|
|
// When selecting both aggregate and grouping columns
|
|
let sql = "SELECT user_id, sum(id) FROM orders GROUP BY user_id";
|
|
let plan = parse_and_build(sql, &schema).unwrap();
|
|
|
|
// No projection needed - aggregate outputs exactly what we select
|
|
match plan {
|
|
LogicalPlan::Aggregate(agg) => {
|
|
assert_eq!(agg.group_expr.len(), 1);
|
|
assert_eq!(agg.aggr_expr.len(), 1);
|
|
assert_eq!(agg.schema.column_count(), 2);
|
|
}
|
|
_ => panic!("Expected Aggregate (no projection)"),
|
|
}
|
|
}
|
|
|
|
#[test]
|
|
fn test_scalar_function_wrapping_aggregate_no_group_by() {
|
|
// Test: SELECT HEX(SUM(age + 2)) FROM users
|
|
// Expected structure:
|
|
// Projection { exprs: [ScalarFunction(HEX, [Column])] }
|
|
// -> Aggregate { aggr_expr: [Sum(BinaryExpr(age + 2))], group_expr: [] }
|
|
// -> Projection { exprs: [BinaryExpr(age + 2)] }
|
|
// -> TableScan("users")
|
|
|
|
let schema = create_test_schema();
|
|
let sql = "SELECT HEX(SUM(age + 2)) FROM users";
|
|
let mut parser = Parser::new(sql.as_bytes());
|
|
let stmt = parser.next().unwrap().unwrap();
|
|
|
|
let plan = match stmt {
|
|
ast::Cmd::Stmt(stmt) => {
|
|
let mut builder = LogicalPlanBuilder::new(&schema);
|
|
builder.build_statement(&stmt).unwrap()
|
|
}
|
|
_ => panic!("Expected SQL statement"),
|
|
};
|
|
|
|
match &plan {
|
|
LogicalPlan::Projection(proj) => {
|
|
assert_eq!(proj.exprs.len(), 1, "Should have one expression");
|
|
|
|
match &proj.exprs[0] {
|
|
LogicalExpr::ScalarFunction { fun, args } => {
|
|
assert_eq!(fun, "HEX", "Outer function should be HEX");
|
|
assert_eq!(args.len(), 1, "HEX should have one argument");
|
|
|
|
match &args[0] {
|
|
LogicalExpr::Column(_) => {}
|
|
LogicalExpr::AggregateFunction { .. } => {
|
|
panic!("Aggregate function should not be embedded in projection! It should be in a separate Aggregate operator");
|
|
}
|
|
_ => panic!(
|
|
"Expected column reference as argument to HEX, got: {:?}",
|
|
args[0]
|
|
),
|
|
}
|
|
}
|
|
_ => panic!("Expected ScalarFunction (HEX), got: {:?}", proj.exprs[0]),
|
|
}
|
|
|
|
match &*proj.input {
|
|
LogicalPlan::Aggregate(agg) => {
|
|
assert_eq!(agg.group_expr.len(), 0, "Should have no GROUP BY");
|
|
assert_eq!(
|
|
agg.aggr_expr.len(),
|
|
1,
|
|
"Should have one aggregate expression"
|
|
);
|
|
|
|
match &agg.aggr_expr[0] {
|
|
LogicalExpr::AggregateFunction {
|
|
fun,
|
|
args,
|
|
distinct,
|
|
} => {
|
|
assert_eq!(*fun, crate::function::AggFunc::Sum, "Should be SUM");
|
|
assert!(!distinct, "Should not be DISTINCT");
|
|
assert_eq!(args.len(), 1, "SUM should have one argument");
|
|
|
|
match &args[0] {
|
|
LogicalExpr::Column(col) => {
|
|
// When aggregate arguments are complex, they get pre-projected
|
|
assert!(col.name.starts_with("__agg_arg_proj_"),
|
|
"Should reference pre-projected column, got: {}", col.name);
|
|
}
|
|
LogicalExpr::BinaryExpr { left, op, right } => {
|
|
// Simple case without pre-projection (shouldn't happen with current implementation)
|
|
assert_eq!(*op, ast::Operator::Add, "Should be addition");
|
|
|
|
match (&**left, &**right) {
|
|
(LogicalExpr::Column(col), LogicalExpr::Literal(val)) => {
|
|
assert_eq!(col.name, "age", "Should reference age column");
|
|
assert_eq!(*val, Value::Integer(2), "Should add 2");
|
|
}
|
|
_ => panic!("Expected age + 2"),
|
|
}
|
|
}
|
|
_ => panic!("Expected Column reference or BinaryExpr for aggregate argument, got: {:?}", args[0]),
|
|
}
|
|
}
|
|
_ => panic!("Expected AggregateFunction"),
|
|
}
|
|
|
|
match &*agg.input {
|
|
LogicalPlan::TableScan(scan) => {
|
|
assert_eq!(scan.table_name, "users");
|
|
}
|
|
LogicalPlan::Projection(proj) => match &*proj.input {
|
|
LogicalPlan::TableScan(scan) => {
|
|
assert_eq!(scan.table_name, "users");
|
|
}
|
|
_ => panic!("Expected TableScan under projection"),
|
|
},
|
|
_ => panic!("Expected TableScan or Projection under Aggregate"),
|
|
}
|
|
}
|
|
_ => panic!(
|
|
"Expected Aggregate operator under Projection, got: {:?}",
|
|
proj.input
|
|
),
|
|
}
|
|
}
|
|
_ => panic!("Expected Projection as top-level operator, got: {plan:?}"),
|
|
}
|
|
}
|
|
|
|
// ===== JOIN TESTS =====
|
|
|
|
#[test]
|
|
fn test_inner_join() {
|
|
let schema = create_test_schema();
|
|
let sql = "SELECT * FROM users u INNER JOIN orders o ON u.id = o.user_id";
|
|
let plan = parse_and_build(sql, &schema).unwrap();
|
|
|
|
match plan {
|
|
LogicalPlan::Projection(proj) => {
|
|
match &*proj.input {
|
|
LogicalPlan::Join(join) => {
|
|
assert_eq!(join.join_type, JoinType::Inner);
|
|
assert!(!join.on.is_empty(), "Should have join conditions");
|
|
|
|
// Check left input is users
|
|
match &*join.left {
|
|
LogicalPlan::TableScan(scan) => {
|
|
assert_eq!(scan.table_name, "users");
|
|
}
|
|
_ => panic!("Expected TableScan for left input"),
|
|
}
|
|
|
|
// Check right input is orders
|
|
match &*join.right {
|
|
LogicalPlan::TableScan(scan) => {
|
|
assert_eq!(scan.table_name, "orders");
|
|
}
|
|
_ => panic!("Expected TableScan for right input"),
|
|
}
|
|
}
|
|
_ => panic!("Expected Join under Projection"),
|
|
}
|
|
}
|
|
_ => panic!("Expected Projection at top level"),
|
|
}
|
|
}
|
|
|
|
#[test]
|
|
fn test_left_join() {
|
|
let schema = create_test_schema();
|
|
let sql = "SELECT u.name, o.amount FROM users u LEFT JOIN orders o ON u.id = o.user_id";
|
|
let plan = parse_and_build(sql, &schema).unwrap();
|
|
|
|
match plan {
|
|
LogicalPlan::Projection(proj) => {
|
|
assert_eq!(proj.exprs.len(), 2); // name and amount
|
|
match &*proj.input {
|
|
LogicalPlan::Join(join) => {
|
|
assert_eq!(join.join_type, JoinType::Left);
|
|
assert!(!join.on.is_empty(), "Should have join conditions");
|
|
}
|
|
_ => panic!("Expected Join under Projection"),
|
|
}
|
|
}
|
|
_ => panic!("Expected Projection at top level"),
|
|
}
|
|
}
|
|
|
|
#[test]
|
|
fn test_right_join() {
|
|
let schema = create_test_schema();
|
|
let sql = "SELECT * FROM orders o RIGHT JOIN users u ON o.user_id = u.id";
|
|
let plan = parse_and_build(sql, &schema).unwrap();
|
|
|
|
match plan {
|
|
LogicalPlan::Projection(proj) => match &*proj.input {
|
|
LogicalPlan::Join(join) => {
|
|
assert_eq!(join.join_type, JoinType::Right);
|
|
assert!(!join.on.is_empty(), "Should have join conditions");
|
|
}
|
|
_ => panic!("Expected Join under Projection"),
|
|
},
|
|
_ => panic!("Expected Projection at top level"),
|
|
}
|
|
}
|
|
|
|
#[test]
|
|
fn test_full_outer_join() {
|
|
let schema = create_test_schema();
|
|
let sql = "SELECT * FROM users u FULL OUTER JOIN orders o ON u.id = o.user_id";
|
|
let plan = parse_and_build(sql, &schema).unwrap();
|
|
|
|
match plan {
|
|
LogicalPlan::Projection(proj) => match &*proj.input {
|
|
LogicalPlan::Join(join) => {
|
|
assert_eq!(join.join_type, JoinType::Full);
|
|
assert!(!join.on.is_empty(), "Should have join conditions");
|
|
}
|
|
_ => panic!("Expected Join under Projection"),
|
|
},
|
|
_ => panic!("Expected Projection at top level"),
|
|
}
|
|
}
|
|
|
|
#[test]
|
|
fn test_cross_join() {
|
|
let schema = create_test_schema();
|
|
let sql = "SELECT * FROM users CROSS JOIN orders";
|
|
let plan = parse_and_build(sql, &schema).unwrap();
|
|
|
|
match plan {
|
|
LogicalPlan::Projection(proj) => match &*proj.input {
|
|
LogicalPlan::Join(join) => {
|
|
assert_eq!(join.join_type, JoinType::Cross);
|
|
assert!(join.on.is_empty(), "Cross join should have no conditions");
|
|
assert!(join.filter.is_none(), "Cross join should have no filter");
|
|
}
|
|
_ => panic!("Expected Join under Projection"),
|
|
},
|
|
_ => panic!("Expected Projection at top level"),
|
|
}
|
|
}
|
|
|
|
#[test]
|
|
fn test_join_with_multiple_conditions() {
|
|
let schema = create_test_schema();
|
|
let sql = "SELECT * FROM users u JOIN orders o ON u.id = o.user_id AND u.age > 18";
|
|
let plan = parse_and_build(sql, &schema).unwrap();
|
|
|
|
match plan {
|
|
LogicalPlan::Projection(proj) => {
|
|
match &*proj.input {
|
|
LogicalPlan::Join(join) => {
|
|
assert_eq!(join.join_type, JoinType::Inner);
|
|
// Should have at least one equijoin condition
|
|
assert!(!join.on.is_empty(), "Should have join conditions");
|
|
// Additional conditions may be in filter
|
|
// The exact distribution depends on our implementation
|
|
}
|
|
_ => panic!("Expected Join under Projection"),
|
|
}
|
|
}
|
|
_ => panic!("Expected Projection at top level"),
|
|
}
|
|
}
|
|
|
|
#[test]
|
|
fn test_join_using_clause() {
|
|
let schema = create_test_schema();
|
|
// Note: Both tables should have an 'id' column for this to work
|
|
let sql = "SELECT * FROM users JOIN orders USING (id)";
|
|
let plan = parse_and_build(sql, &schema).unwrap();
|
|
|
|
match plan {
|
|
LogicalPlan::Projection(proj) => match &*proj.input {
|
|
LogicalPlan::Join(join) => {
|
|
assert_eq!(join.join_type, JoinType::Inner);
|
|
assert!(
|
|
!join.on.is_empty(),
|
|
"USING clause should create join conditions"
|
|
);
|
|
}
|
|
_ => panic!("Expected Join under Projection"),
|
|
},
|
|
_ => panic!("Expected Projection at top level"),
|
|
}
|
|
}
|
|
|
|
#[test]
|
|
fn test_natural_join() {
|
|
let schema = create_test_schema();
|
|
let sql = "SELECT * FROM users NATURAL JOIN orders";
|
|
let plan = parse_and_build(sql, &schema).unwrap();
|
|
|
|
match plan {
|
|
LogicalPlan::Projection(proj) => {
|
|
match &*proj.input {
|
|
LogicalPlan::Join(join) => {
|
|
// Natural join finds common columns (id in this case)
|
|
// If no common columns, it becomes a cross join
|
|
assert!(
|
|
!join.on.is_empty() || join.join_type == JoinType::Cross,
|
|
"Natural join should either find common columns or become cross join"
|
|
);
|
|
}
|
|
_ => panic!("Expected Join under Projection"),
|
|
}
|
|
}
|
|
_ => panic!("Expected Projection at top level"),
|
|
}
|
|
}
|
|
|
|
#[test]
|
|
fn test_three_way_join() {
|
|
let schema = create_test_schema();
|
|
let sql = "SELECT * FROM users u
|
|
JOIN orders o ON u.id = o.user_id
|
|
JOIN products p ON o.product_id = p.id";
|
|
let plan = parse_and_build(sql, &schema).unwrap();
|
|
|
|
match plan {
|
|
LogicalPlan::Projection(proj) => {
|
|
match &*proj.input {
|
|
LogicalPlan::Join(join2) => {
|
|
// Second join (with products)
|
|
assert_eq!(join2.join_type, JoinType::Inner);
|
|
match &*join2.left {
|
|
LogicalPlan::Join(join1) => {
|
|
// First join (users with orders)
|
|
assert_eq!(join1.join_type, JoinType::Inner);
|
|
}
|
|
_ => panic!("Expected nested Join for three-way join"),
|
|
}
|
|
}
|
|
_ => panic!("Expected Join under Projection"),
|
|
}
|
|
}
|
|
_ => panic!("Expected Projection at top level"),
|
|
}
|
|
}
|
|
|
|
#[test]
|
|
fn test_mixed_join_types() {
|
|
let schema = create_test_schema();
|
|
let sql = "SELECT * FROM users u
|
|
LEFT JOIN orders o ON u.id = o.user_id
|
|
INNER JOIN products p ON o.product_id = p.id";
|
|
let plan = parse_and_build(sql, &schema).unwrap();
|
|
|
|
match plan {
|
|
LogicalPlan::Projection(proj) => {
|
|
match &*proj.input {
|
|
LogicalPlan::Join(join2) => {
|
|
// Second join should be INNER
|
|
assert_eq!(join2.join_type, JoinType::Inner);
|
|
match &*join2.left {
|
|
LogicalPlan::Join(join1) => {
|
|
// First join should be LEFT
|
|
assert_eq!(join1.join_type, JoinType::Left);
|
|
}
|
|
_ => panic!("Expected nested Join"),
|
|
}
|
|
}
|
|
_ => panic!("Expected Join under Projection"),
|
|
}
|
|
}
|
|
_ => panic!("Expected Projection at top level"),
|
|
}
|
|
}
|
|
|
|
#[test]
|
|
fn test_join_with_filter() {
|
|
let schema = create_test_schema();
|
|
let sql = "SELECT * FROM users u JOIN orders o ON u.id = o.user_id WHERE o.amount > 100";
|
|
let plan = parse_and_build(sql, &schema).unwrap();
|
|
|
|
match plan {
|
|
LogicalPlan::Projection(proj) => {
|
|
match &*proj.input {
|
|
LogicalPlan::Filter(filter) => {
|
|
// WHERE clause creates a Filter above the Join
|
|
match &*filter.input {
|
|
LogicalPlan::Join(join) => {
|
|
assert_eq!(join.join_type, JoinType::Inner);
|
|
}
|
|
_ => panic!("Expected Join under Filter"),
|
|
}
|
|
}
|
|
_ => panic!("Expected Filter under Projection"),
|
|
}
|
|
}
|
|
_ => panic!("Expected Projection at top level"),
|
|
}
|
|
}
|
|
|
|
#[test]
|
|
fn test_join_with_projection() {
|
|
let schema = create_test_schema();
|
|
let sql = "SELECT u.name, o.amount FROM users u JOIN orders o ON u.id = o.user_id";
|
|
let plan = parse_and_build(sql, &schema).unwrap();
|
|
|
|
match plan {
|
|
LogicalPlan::Projection(proj) => {
|
|
assert_eq!(proj.exprs.len(), 2); // u.name and o.amount
|
|
match &*proj.input {
|
|
LogicalPlan::Join(join) => {
|
|
assert_eq!(join.join_type, JoinType::Inner);
|
|
}
|
|
_ => panic!("Expected Join under Projection"),
|
|
}
|
|
}
|
|
_ => panic!("Expected Projection at top level"),
|
|
}
|
|
}
|
|
|
|
#[test]
|
|
fn test_join_with_aggregation() {
|
|
let schema = create_test_schema();
|
|
let sql = "SELECT u.name, SUM(o.amount)
|
|
FROM users u JOIN orders o ON u.id = o.user_id
|
|
GROUP BY u.name";
|
|
let plan = parse_and_build(sql, &schema).unwrap();
|
|
|
|
match plan {
|
|
LogicalPlan::Aggregate(agg) => {
|
|
assert_eq!(agg.group_expr.len(), 1); // GROUP BY u.name
|
|
assert_eq!(agg.aggr_expr.len(), 1); // SUM(o.amount)
|
|
match &*agg.input {
|
|
LogicalPlan::Join(join) => {
|
|
assert_eq!(join.join_type, JoinType::Inner);
|
|
}
|
|
_ => panic!("Expected Join under Aggregate"),
|
|
}
|
|
}
|
|
_ => panic!("Expected Aggregate"),
|
|
}
|
|
}
|
|
|
|
#[test]
|
|
fn test_join_with_order_by() {
|
|
let schema = create_test_schema();
|
|
let sql = "SELECT * FROM users u JOIN orders o ON u.id = o.user_id ORDER BY o.amount DESC";
|
|
let plan = parse_and_build(sql, &schema).unwrap();
|
|
|
|
match plan {
|
|
LogicalPlan::Sort(sort) => {
|
|
assert_eq!(sort.exprs.len(), 1);
|
|
assert!(!sort.exprs[0].asc); // DESC
|
|
match &*sort.input {
|
|
LogicalPlan::Projection(proj) => match &*proj.input {
|
|
LogicalPlan::Join(join) => {
|
|
assert_eq!(join.join_type, JoinType::Inner);
|
|
}
|
|
_ => panic!("Expected Join under Projection"),
|
|
},
|
|
_ => panic!("Expected Projection under Sort"),
|
|
}
|
|
}
|
|
_ => panic!("Expected Sort at top level"),
|
|
}
|
|
}
|
|
|
|
#[test]
|
|
fn test_join_in_subquery() {
|
|
let schema = create_test_schema();
|
|
let sql = "SELECT * FROM (
|
|
SELECT u.id, u.name, o.amount
|
|
FROM users u JOIN orders o ON u.id = o.user_id
|
|
) WHERE amount > 100";
|
|
let plan = parse_and_build(sql, &schema).unwrap();
|
|
|
|
match plan {
|
|
LogicalPlan::Projection(outer_proj) => match &*outer_proj.input {
|
|
LogicalPlan::Filter(filter) => match &*filter.input {
|
|
LogicalPlan::Projection(inner_proj) => match &*inner_proj.input {
|
|
LogicalPlan::Join(join) => {
|
|
assert_eq!(join.join_type, JoinType::Inner);
|
|
}
|
|
_ => panic!("Expected Join in subquery"),
|
|
},
|
|
_ => panic!("Expected Projection for subquery"),
|
|
},
|
|
_ => panic!("Expected Filter"),
|
|
},
|
|
_ => panic!("Expected Projection at top level"),
|
|
}
|
|
}
|
|
|
|
#[test]
|
|
fn test_join_ambiguous_column() {
|
|
let schema = create_test_schema();
|
|
// Both users and orders have an 'id' column
|
|
let sql = "SELECT id FROM users JOIN orders ON users.id = orders.user_id";
|
|
let result = parse_and_build(sql, &schema);
|
|
// This might error or succeed depending on how we handle ambiguous columns
|
|
// For now, just check that parsing completes
|
|
match result {
|
|
Ok(_) => {
|
|
// If successful, the implementation handles ambiguous columns somehow
|
|
}
|
|
Err(_) => {
|
|
// If error, the implementation rejects ambiguous columns
|
|
}
|
|
}
|
|
}
|
|
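    // A minimal sketch of the arithmetic rules in `infer_expr_type`, assuming it stays an
    // associated function on `LogicalPlanBuilder` and that an empty schema is sufficient
    // for literal-only expressions.
    #[test]
    fn test_infer_expr_type_arithmetic_literals() {
        let schema: SchemaRef = Arc::new(LogicalSchema::empty());

        // Integer + Integer stays Integer.
        let int_add = LogicalExpr::BinaryExpr {
            left: Box::new(LogicalExpr::Literal(Value::Integer(1))),
            op: ast::Operator::Add,
            right: Box::new(LogicalExpr::Literal(Value::Integer(2))),
        };
        assert_eq!(
            LogicalPlanBuilder::infer_expr_type(&int_add, &schema).unwrap(),
            Type::Integer
        );

        // Mixing Integer and Real widens to Real.
        let mixed_mul = LogicalExpr::BinaryExpr {
            left: Box::new(LogicalExpr::Literal(Value::Integer(3))),
            op: ast::Operator::Multiply,
            right: Box::new(LogicalExpr::Literal(Value::Float(1.5))),
        };
        assert_eq!(
            LogicalPlanBuilder::infer_expr_type(&mixed_mul, &schema).unwrap(),
            Type::Real
        );

        // Division always produces Real.
        let div = LogicalExpr::BinaryExpr {
            left: Box::new(LogicalExpr::Literal(Value::Integer(4))),
            op: ast::Operator::Divide,
            right: Box::new(LogicalExpr::Literal(Value::Integer(2))),
        };
        assert_eq!(
            LogicalPlanBuilder::infer_expr_type(&div, &schema).unwrap(),
            Type::Real
        );
    }
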
}
|