add joins to the logical plan

This commit is contained in:
Glauber Costa
2025-09-04 19:35:49 -05:00
parent 5b4a6e5c2d
commit 2e7a45559b

View File

@@ -25,6 +25,9 @@ type PreprocessAggregateResult = (
Vec<LogicalExpr>, // modified_aggr_exprs
);
/// Result type for parsing join conditions
type JoinConditionsResult = (Vec<(LogicalExpr, LogicalExpr)>, Option<LogicalExpr>);
/// Schema information for logical plan nodes
#[derive(Debug, Clone, PartialEq)]
pub struct LogicalSchema {
@@ -66,8 +69,8 @@ pub enum LogicalPlan {
Filter(Filter),
/// Aggregate - GROUP BY with aggregate functions
Aggregate(Aggregate),
// TODO: Join - combining two relations (not yet implemented)
// Join(Join),
/// Join - combining two relations
Join(Join),
/// Sort - ORDER BY clause
Sort(Sort),
/// Limit - LIMIT/OFFSET clause
@@ -95,7 +98,7 @@ impl LogicalPlan {
LogicalPlan::Projection(p) => &p.schema,
LogicalPlan::Filter(f) => f.input.schema(),
LogicalPlan::Aggregate(a) => &a.schema,
// LogicalPlan::Join(j) => &j.schema,
LogicalPlan::Join(j) => &j.schema,
LogicalPlan::Sort(s) => s.input.schema(),
LogicalPlan::Limit(l) => l.input.schema(),
LogicalPlan::TableScan(t) => &t.schema,
@@ -133,26 +136,26 @@ pub struct Aggregate {
pub schema: SchemaRef,
}
// TODO: Join operator (not yet implemented)
// #[derive(Debug, Clone, PartialEq)]
// pub struct Join {
// pub left: Arc<LogicalPlan>,
// pub right: Arc<LogicalPlan>,
// pub join_type: JoinType,
// pub on: Vec<(LogicalExpr, LogicalExpr)>, // Equijoin conditions
// pub filter: Option<LogicalExpr>, // Additional filter conditions
// pub schema: SchemaRef,
// }
/// Types of joins
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum JoinType {
Inner,
Left,
Right,
Full,
Cross,
}
// TODO: Types of joins (not yet implemented)
// #[derive(Debug, Clone, Copy, PartialEq)]
// pub enum JoinType {
// Inner,
// Left,
// Right,
// Full,
// Cross,
// }
/// Join operator - combines two relations
#[derive(Debug, Clone, PartialEq)]
pub struct Join {
pub left: Arc<LogicalPlan>,
pub right: Arc<LogicalPlan>,
pub join_type: JoinType,
pub on: Vec<(LogicalExpr, LogicalExpr)>, // Equijoin conditions (left_expr, right_expr)
pub filter: Option<LogicalExpr>, // Additional filter conditions
pub schema: SchemaRef,
}
/// Sort operator - ORDER BY
#[derive(Debug, Clone, PartialEq)]
@@ -570,14 +573,279 @@ impl<'a> LogicalPlanBuilder<'a> {
// Build JOIN
fn build_join(
&mut self,
_left: LogicalPlan,
_right: LogicalPlan,
_op: &ast::JoinOperator,
_constraint: &Option<ast::JoinConstraint>,
left: LogicalPlan,
right: LogicalPlan,
op: &ast::JoinOperator,
constraint: &Option<ast::JoinConstraint>,
) -> Result<LogicalPlan> {
Err(LimboError::ParseError(
"JOINs are not yet supported in logical plans".to_string(),
))
// Determine join type
let join_type = match op {
ast::JoinOperator::Comma => JoinType::Cross, // Comma is essentially a cross join
ast::JoinOperator::TypedJoin(Some(jt)) => {
// Check the join type flags
// Note: JoinType can have multiple flags set
if jt.contains(ast::JoinType::NATURAL) {
// Natural joins need special handling - find common columns
return self.build_natural_join(left, right, JoinType::Inner);
} else if jt.contains(ast::JoinType::LEFT)
&& jt.contains(ast::JoinType::RIGHT)
&& jt.contains(ast::JoinType::OUTER)
{
// FULL OUTER JOIN (has LEFT, RIGHT, and OUTER)
JoinType::Full
} else if jt.contains(ast::JoinType::LEFT) && jt.contains(ast::JoinType::OUTER) {
JoinType::Left
} else if jt.contains(ast::JoinType::RIGHT) && jt.contains(ast::JoinType::OUTER) {
JoinType::Right
} else if jt.contains(ast::JoinType::OUTER)
&& !jt.contains(ast::JoinType::LEFT)
&& !jt.contains(ast::JoinType::RIGHT)
{
// Plain OUTER JOIN should also be FULL
JoinType::Full
} else if jt.contains(ast::JoinType::LEFT) {
JoinType::Left
} else if jt.contains(ast::JoinType::RIGHT) {
JoinType::Right
} else if jt.contains(ast::JoinType::CROSS)
|| (jt.contains(ast::JoinType::INNER) && jt.contains(ast::JoinType::CROSS))
{
JoinType::Cross
} else {
JoinType::Inner // Default to inner
}
}
ast::JoinOperator::TypedJoin(None) => JoinType::Inner, // Default JOIN is INNER JOIN
};
// Build join conditions
let (on_conditions, filter) = match constraint {
Some(ast::JoinConstraint::On(expr)) => {
// Parse ON clause into equijoin conditions and filters
self.parse_join_conditions(expr, left.schema(), right.schema())?
}
Some(ast::JoinConstraint::Using(columns)) => {
// Build equijoin conditions from USING clause
let on = self.build_using_conditions(columns, left.schema(), right.schema())?;
(on, None)
}
None => {
// Cross join or natural join
(Vec::new(), None)
}
};
// Build combined schema
let schema = self.build_join_schema(&left, &right, &join_type)?;
Ok(LogicalPlan::Join(Join {
left: Arc::new(left),
right: Arc::new(right),
join_type,
on: on_conditions,
filter,
schema,
}))
}
// Helper: Parse join conditions into equijoins and filters
fn parse_join_conditions(
&mut self,
expr: &ast::Expr,
left_schema: &SchemaRef,
right_schema: &SchemaRef,
) -> Result<JoinConditionsResult> {
// For now, we'll handle simple equality conditions
// More complex conditions will go into the filter
let mut equijoins = Vec::new();
let mut filters = Vec::new();
// Try to extract equijoin conditions from the expression
self.extract_equijoin_conditions(
expr,
left_schema,
right_schema,
&mut equijoins,
&mut filters,
)?;
let filter = if filters.is_empty() {
None
} else {
// Combine multiple filters with AND
Some(
filters
.into_iter()
.reduce(|acc, e| LogicalExpr::BinaryExpr {
left: Box::new(acc),
op: BinaryOperator::And,
right: Box::new(e),
})
.unwrap(),
)
};
Ok((equijoins, filter))
}
// Helper: Extract equijoin conditions from expression
fn extract_equijoin_conditions(
&mut self,
expr: &ast::Expr,
left_schema: &SchemaRef,
right_schema: &SchemaRef,
equijoins: &mut Vec<(LogicalExpr, LogicalExpr)>,
filters: &mut Vec<LogicalExpr>,
) -> Result<()> {
match expr {
ast::Expr::Binary(lhs, ast::Operator::Equals, rhs) => {
// Check if this is an equijoin condition (left.col = right.col)
let left_expr = self.build_expr(lhs, left_schema)?;
let right_expr = self.build_expr(rhs, right_schema)?;
// For simplicity, we'll check if one references left and one references right
// In a real implementation, we'd need more sophisticated column resolution
equijoins.push((left_expr, right_expr));
}
ast::Expr::Binary(lhs, ast::Operator::And, rhs) => {
// Recursively extract from AND conditions
self.extract_equijoin_conditions(
lhs,
left_schema,
right_schema,
equijoins,
filters,
)?;
self.extract_equijoin_conditions(
rhs,
left_schema,
right_schema,
equijoins,
filters,
)?;
}
_ => {
// Other conditions go into the filter
// We need a combined schema to build the expression
let combined_schema = self.combine_schemas(left_schema, right_schema)?;
let filter_expr = self.build_expr(expr, &combined_schema)?;
filters.push(filter_expr);
}
}
Ok(())
}
// Helper: Build equijoin conditions from USING clause
fn build_using_conditions(
&mut self,
columns: &[ast::Name],
left_schema: &SchemaRef,
right_schema: &SchemaRef,
) -> Result<Vec<(LogicalExpr, LogicalExpr)>> {
let mut conditions = Vec::new();
for col_name in columns {
let name = Self::name_to_string(col_name);
// Find the column in both schemas
let _left_idx = left_schema
.columns
.iter()
.position(|(n, _)| n == &name)
.ok_or_else(|| {
LimboError::ParseError(format!("Column {name} not found in left table"))
})?;
let _right_idx = right_schema
.columns
.iter()
.position(|(n, _)| n == &name)
.ok_or_else(|| {
LimboError::ParseError(format!("Column {name} not found in right table"))
})?;
conditions.push((
LogicalExpr::Column(Column {
name: name.clone(),
table: None, // Will be resolved later
}),
LogicalExpr::Column(Column {
name,
table: None, // Will be resolved later
}),
));
}
Ok(conditions)
}
// Helper: Build natural join by finding common columns
fn build_natural_join(
&mut self,
left: LogicalPlan,
right: LogicalPlan,
join_type: JoinType,
) -> Result<LogicalPlan> {
let left_schema = left.schema();
let right_schema = right.schema();
// Find common column names
let mut common_columns = Vec::new();
for (left_name, _) in &left_schema.columns {
if right_schema.columns.iter().any(|(n, _)| n == left_name) {
common_columns.push(ast::Name::Ident(left_name.clone()));
}
}
if common_columns.is_empty() {
// Natural join with no common columns becomes a cross join
let schema = self.build_join_schema(&left, &right, &JoinType::Cross)?;
return Ok(LogicalPlan::Join(Join {
left: Arc::new(left),
right: Arc::new(right),
join_type: JoinType::Cross,
on: Vec::new(),
filter: None,
schema,
}));
}
// Build equijoin conditions for common columns
let on = self.build_using_conditions(&common_columns, left_schema, right_schema)?;
let schema = self.build_join_schema(&left, &right, &join_type)?;
Ok(LogicalPlan::Join(Join {
left: Arc::new(left),
right: Arc::new(right),
join_type,
on,
filter: None,
schema,
}))
}
// Helper: Build schema for join result
fn build_join_schema(
&self,
left: &LogicalPlan,
right: &LogicalPlan,
_join_type: &JoinType,
) -> Result<SchemaRef> {
let left_schema = left.schema();
let right_schema = right.schema();
// For now, simply concatenate the schemas
// In a real implementation, we'd handle column name conflicts and nullable columns
let mut columns = left_schema.columns.clone();
columns.extend(right_schema.columns.clone());
Ok(Arc::new(LogicalSchema::new(columns)))
}
// Helper: Combine two schemas for expression building
fn combine_schemas(&self, left: &SchemaRef, right: &SchemaRef) -> Result<SchemaRef> {
let mut columns = left.columns.clone();
columns.extend(right.columns.clone());
Ok(Arc::new(LogicalSchema::new(columns)))
}
// Build projection
@@ -1974,6 +2242,67 @@ mod tests {
};
schema.add_btree_table(Arc::new(orders_table));
// Create products table
let products_table = BTreeTable {
name: "products".to_string(),
root_page: 4,
primary_key_columns: vec![("id".to_string(), turso_parser::ast::SortOrder::Asc)],
columns: vec![
SchemaColumn {
name: Some("id".to_string()),
ty: Type::Integer,
ty_str: "INTEGER".to_string(),
primary_key: true,
is_rowid_alias: true,
notnull: true,
default: None,
unique: false,
collation: None,
hidden: false,
},
SchemaColumn {
name: Some("name".to_string()),
ty: Type::Text,
ty_str: "TEXT".to_string(),
primary_key: false,
is_rowid_alias: false,
notnull: false,
default: None,
unique: false,
collation: None,
hidden: false,
},
SchemaColumn {
name: Some("price".to_string()),
ty: Type::Real,
ty_str: "REAL".to_string(),
primary_key: false,
is_rowid_alias: false,
notnull: false,
default: None,
unique: false,
collation: None,
hidden: false,
},
SchemaColumn {
name: Some("product_id".to_string()),
ty: Type::Integer,
ty_str: "INTEGER".to_string(),
primary_key: false,
is_rowid_alias: false,
notnull: false,
default: None,
unique: false,
collation: None,
hidden: false,
},
],
has_rowid: true,
is_strict: false,
unique_sets: vec![],
};
schema.add_btree_table(Arc::new(products_table));
schema
}
@@ -3086,4 +3415,381 @@ mod tests {
_ => panic!("Expected Projection as top-level operator, got: {plan:?}"),
}
}
// ===== JOIN TESTS =====
#[test]
fn test_inner_join() {
let schema = create_test_schema();
let sql = "SELECT * FROM users u INNER JOIN orders o ON u.id = o.user_id";
let plan = parse_and_build(sql, &schema).unwrap();
match plan {
LogicalPlan::Projection(proj) => {
match &*proj.input {
LogicalPlan::Join(join) => {
assert_eq!(join.join_type, JoinType::Inner);
assert!(!join.on.is_empty(), "Should have join conditions");
// Check left input is users
match &*join.left {
LogicalPlan::TableScan(scan) => {
assert_eq!(scan.table_name, "users");
}
_ => panic!("Expected TableScan for left input"),
}
// Check right input is orders
match &*join.right {
LogicalPlan::TableScan(scan) => {
assert_eq!(scan.table_name, "orders");
}
_ => panic!("Expected TableScan for right input"),
}
}
_ => panic!("Expected Join under Projection"),
}
}
_ => panic!("Expected Projection at top level"),
}
}
#[test]
fn test_left_join() {
let schema = create_test_schema();
let sql = "SELECT u.name, o.amount FROM users u LEFT JOIN orders o ON u.id = o.user_id";
let plan = parse_and_build(sql, &schema).unwrap();
match plan {
LogicalPlan::Projection(proj) => {
assert_eq!(proj.exprs.len(), 2); // name and amount
match &*proj.input {
LogicalPlan::Join(join) => {
assert_eq!(join.join_type, JoinType::Left);
assert!(!join.on.is_empty(), "Should have join conditions");
}
_ => panic!("Expected Join under Projection"),
}
}
_ => panic!("Expected Projection at top level"),
}
}
#[test]
fn test_right_join() {
let schema = create_test_schema();
let sql = "SELECT * FROM orders o RIGHT JOIN users u ON o.user_id = u.id";
let plan = parse_and_build(sql, &schema).unwrap();
match plan {
LogicalPlan::Projection(proj) => match &*proj.input {
LogicalPlan::Join(join) => {
assert_eq!(join.join_type, JoinType::Right);
assert!(!join.on.is_empty(), "Should have join conditions");
}
_ => panic!("Expected Join under Projection"),
},
_ => panic!("Expected Projection at top level"),
}
}
#[test]
fn test_full_outer_join() {
let schema = create_test_schema();
let sql = "SELECT * FROM users u FULL OUTER JOIN orders o ON u.id = o.user_id";
let plan = parse_and_build(sql, &schema).unwrap();
match plan {
LogicalPlan::Projection(proj) => match &*proj.input {
LogicalPlan::Join(join) => {
assert_eq!(join.join_type, JoinType::Full);
assert!(!join.on.is_empty(), "Should have join conditions");
}
_ => panic!("Expected Join under Projection"),
},
_ => panic!("Expected Projection at top level"),
}
}
#[test]
fn test_cross_join() {
let schema = create_test_schema();
let sql = "SELECT * FROM users CROSS JOIN orders";
let plan = parse_and_build(sql, &schema).unwrap();
match plan {
LogicalPlan::Projection(proj) => match &*proj.input {
LogicalPlan::Join(join) => {
assert_eq!(join.join_type, JoinType::Cross);
assert!(join.on.is_empty(), "Cross join should have no conditions");
assert!(join.filter.is_none(), "Cross join should have no filter");
}
_ => panic!("Expected Join under Projection"),
},
_ => panic!("Expected Projection at top level"),
}
}
#[test]
fn test_join_with_multiple_conditions() {
let schema = create_test_schema();
let sql = "SELECT * FROM users u JOIN orders o ON u.id = o.user_id AND u.age > 18";
let plan = parse_and_build(sql, &schema).unwrap();
match plan {
LogicalPlan::Projection(proj) => {
match &*proj.input {
LogicalPlan::Join(join) => {
assert_eq!(join.join_type, JoinType::Inner);
// Should have at least one equijoin condition
assert!(!join.on.is_empty(), "Should have join conditions");
// Additional conditions may be in filter
// The exact distribution depends on our implementation
}
_ => panic!("Expected Join under Projection"),
}
}
_ => panic!("Expected Projection at top level"),
}
}
#[test]
fn test_join_using_clause() {
let schema = create_test_schema();
// Note: Both tables should have an 'id' column for this to work
let sql = "SELECT * FROM users JOIN orders USING (id)";
let plan = parse_and_build(sql, &schema).unwrap();
match plan {
LogicalPlan::Projection(proj) => match &*proj.input {
LogicalPlan::Join(join) => {
assert_eq!(join.join_type, JoinType::Inner);
assert!(
!join.on.is_empty(),
"USING clause should create join conditions"
);
}
_ => panic!("Expected Join under Projection"),
},
_ => panic!("Expected Projection at top level"),
}
}
#[test]
fn test_natural_join() {
let schema = create_test_schema();
let sql = "SELECT * FROM users NATURAL JOIN orders";
let plan = parse_and_build(sql, &schema).unwrap();
match plan {
LogicalPlan::Projection(proj) => {
match &*proj.input {
LogicalPlan::Join(join) => {
// Natural join finds common columns (id in this case)
// If no common columns, it becomes a cross join
assert!(
!join.on.is_empty() || join.join_type == JoinType::Cross,
"Natural join should either find common columns or become cross join"
);
}
_ => panic!("Expected Join under Projection"),
}
}
_ => panic!("Expected Projection at top level"),
}
}
#[test]
fn test_three_way_join() {
let schema = create_test_schema();
let sql = "SELECT * FROM users u
JOIN orders o ON u.id = o.user_id
JOIN products p ON o.product_id = p.id";
let plan = parse_and_build(sql, &schema).unwrap();
match plan {
LogicalPlan::Projection(proj) => {
match &*proj.input {
LogicalPlan::Join(join2) => {
// Second join (with products)
assert_eq!(join2.join_type, JoinType::Inner);
match &*join2.left {
LogicalPlan::Join(join1) => {
// First join (users with orders)
assert_eq!(join1.join_type, JoinType::Inner);
}
_ => panic!("Expected nested Join for three-way join"),
}
}
_ => panic!("Expected Join under Projection"),
}
}
_ => panic!("Expected Projection at top level"),
}
}
#[test]
fn test_mixed_join_types() {
let schema = create_test_schema();
let sql = "SELECT * FROM users u
LEFT JOIN orders o ON u.id = o.user_id
INNER JOIN products p ON o.product_id = p.id";
let plan = parse_and_build(sql, &schema).unwrap();
match plan {
LogicalPlan::Projection(proj) => {
match &*proj.input {
LogicalPlan::Join(join2) => {
// Second join should be INNER
assert_eq!(join2.join_type, JoinType::Inner);
match &*join2.left {
LogicalPlan::Join(join1) => {
// First join should be LEFT
assert_eq!(join1.join_type, JoinType::Left);
}
_ => panic!("Expected nested Join"),
}
}
_ => panic!("Expected Join under Projection"),
}
}
_ => panic!("Expected Projection at top level"),
}
}
#[test]
fn test_join_with_filter() {
let schema = create_test_schema();
let sql = "SELECT * FROM users u JOIN orders o ON u.id = o.user_id WHERE o.amount > 100";
let plan = parse_and_build(sql, &schema).unwrap();
match plan {
LogicalPlan::Projection(proj) => {
match &*proj.input {
LogicalPlan::Filter(filter) => {
// WHERE clause creates a Filter above the Join
match &*filter.input {
LogicalPlan::Join(join) => {
assert_eq!(join.join_type, JoinType::Inner);
}
_ => panic!("Expected Join under Filter"),
}
}
_ => panic!("Expected Filter under Projection"),
}
}
_ => panic!("Expected Projection at top level"),
}
}
#[test]
fn test_join_with_projection() {
let schema = create_test_schema();
let sql = "SELECT u.name, o.amount FROM users u JOIN orders o ON u.id = o.user_id";
let plan = parse_and_build(sql, &schema).unwrap();
match plan {
LogicalPlan::Projection(proj) => {
assert_eq!(proj.exprs.len(), 2); // u.name and o.amount
match &*proj.input {
LogicalPlan::Join(join) => {
assert_eq!(join.join_type, JoinType::Inner);
}
_ => panic!("Expected Join under Projection"),
}
}
_ => panic!("Expected Projection at top level"),
}
}
#[test]
fn test_join_with_aggregation() {
let schema = create_test_schema();
let sql = "SELECT u.name, SUM(o.amount)
FROM users u JOIN orders o ON u.id = o.user_id
GROUP BY u.name";
let plan = parse_and_build(sql, &schema).unwrap();
match plan {
LogicalPlan::Aggregate(agg) => {
assert_eq!(agg.group_expr.len(), 1); // GROUP BY u.name
assert_eq!(agg.aggr_expr.len(), 1); // SUM(o.amount)
match &*agg.input {
LogicalPlan::Join(join) => {
assert_eq!(join.join_type, JoinType::Inner);
}
_ => panic!("Expected Join under Aggregate"),
}
}
_ => panic!("Expected Aggregate"),
}
}
#[test]
fn test_join_with_order_by() {
let schema = create_test_schema();
let sql = "SELECT * FROM users u JOIN orders o ON u.id = o.user_id ORDER BY o.amount DESC";
let plan = parse_and_build(sql, &schema).unwrap();
match plan {
LogicalPlan::Sort(sort) => {
assert_eq!(sort.exprs.len(), 1);
assert!(!sort.exprs[0].asc); // DESC
match &*sort.input {
LogicalPlan::Projection(proj) => match &*proj.input {
LogicalPlan::Join(join) => {
assert_eq!(join.join_type, JoinType::Inner);
}
_ => panic!("Expected Join under Projection"),
},
_ => panic!("Expected Projection under Sort"),
}
}
_ => panic!("Expected Sort at top level"),
}
}
#[test]
fn test_join_in_subquery() {
let schema = create_test_schema();
let sql = "SELECT * FROM (
SELECT u.id, u.name, o.amount
FROM users u JOIN orders o ON u.id = o.user_id
) WHERE amount > 100";
let plan = parse_and_build(sql, &schema).unwrap();
match plan {
LogicalPlan::Projection(outer_proj) => match &*outer_proj.input {
LogicalPlan::Filter(filter) => match &*filter.input {
LogicalPlan::Projection(inner_proj) => match &*inner_proj.input {
LogicalPlan::Join(join) => {
assert_eq!(join.join_type, JoinType::Inner);
}
_ => panic!("Expected Join in subquery"),
},
_ => panic!("Expected Projection for subquery"),
},
_ => panic!("Expected Filter"),
},
_ => panic!("Expected Projection at top level"),
}
}
#[test]
fn test_join_ambiguous_column() {
let schema = create_test_schema();
// Both users and orders have an 'id' column
let sql = "SELECT id FROM users JOIN orders ON users.id = orders.user_id";
let result = parse_and_build(sql, &schema);
// This might error or succeed depending on how we handle ambiguous columns
// For now, just check that parsing completes
match result {
Ok(_) => {
// If successful, the implementation handles ambiguous columns somehow
}
Err(_) => {
// If error, the implementation rejects ambiguous columns
}
}
}
}