diff --git a/core/translate/logical.rs b/core/translate/logical.rs index aa71f047b..6a8b0a6c2 100644 --- a/core/translate/logical.rs +++ b/core/translate/logical.rs @@ -25,6 +25,9 @@ type PreprocessAggregateResult = ( Vec, // modified_aggr_exprs ); +/// Result type for parsing join conditions +type JoinConditionsResult = (Vec<(LogicalExpr, LogicalExpr)>, Option); + /// Schema information for logical plan nodes #[derive(Debug, Clone, PartialEq)] pub struct LogicalSchema { @@ -66,8 +69,8 @@ pub enum LogicalPlan { Filter(Filter), /// Aggregate - GROUP BY with aggregate functions Aggregate(Aggregate), - // TODO: Join - combining two relations (not yet implemented) - // Join(Join), + /// Join - combining two relations + Join(Join), /// Sort - ORDER BY clause Sort(Sort), /// Limit - LIMIT/OFFSET clause @@ -95,7 +98,7 @@ impl LogicalPlan { LogicalPlan::Projection(p) => &p.schema, LogicalPlan::Filter(f) => f.input.schema(), LogicalPlan::Aggregate(a) => &a.schema, - // LogicalPlan::Join(j) => &j.schema, + LogicalPlan::Join(j) => &j.schema, LogicalPlan::Sort(s) => s.input.schema(), LogicalPlan::Limit(l) => l.input.schema(), LogicalPlan::TableScan(t) => &t.schema, @@ -133,26 +136,26 @@ pub struct Aggregate { pub schema: SchemaRef, } -// TODO: Join operator (not yet implemented) -// #[derive(Debug, Clone, PartialEq)] -// pub struct Join { -// pub left: Arc, -// pub right: Arc, -// pub join_type: JoinType, -// pub on: Vec<(LogicalExpr, LogicalExpr)>, // Equijoin conditions -// pub filter: Option, // Additional filter conditions -// pub schema: SchemaRef, -// } +/// Types of joins +#[derive(Debug, Clone, Copy, PartialEq)] +pub enum JoinType { + Inner, + Left, + Right, + Full, + Cross, +} -// TODO: Types of joins (not yet implemented) -// #[derive(Debug, Clone, Copy, PartialEq)] -// pub enum JoinType { -// Inner, -// Left, -// Right, -// Full, -// Cross, -// } +/// Join operator - combines two relations +#[derive(Debug, Clone, PartialEq)] +pub struct Join { + pub left: Arc, + pub right: Arc, + pub join_type: JoinType, + pub on: Vec<(LogicalExpr, LogicalExpr)>, // Equijoin conditions (left_expr, right_expr) + pub filter: Option, // Additional filter conditions + pub schema: SchemaRef, +} /// Sort operator - ORDER BY #[derive(Debug, Clone, PartialEq)] @@ -570,14 +573,279 @@ impl<'a> LogicalPlanBuilder<'a> { // Build JOIN fn build_join( &mut self, - _left: LogicalPlan, - _right: LogicalPlan, - _op: &ast::JoinOperator, - _constraint: &Option, + left: LogicalPlan, + right: LogicalPlan, + op: &ast::JoinOperator, + constraint: &Option, ) -> Result { - Err(LimboError::ParseError( - "JOINs are not yet supported in logical plans".to_string(), - )) + // Determine join type + let join_type = match op { + ast::JoinOperator::Comma => JoinType::Cross, // Comma is essentially a cross join + ast::JoinOperator::TypedJoin(Some(jt)) => { + // Check the join type flags + // Note: JoinType can have multiple flags set + if jt.contains(ast::JoinType::NATURAL) { + // Natural joins need special handling - find common columns + return self.build_natural_join(left, right, JoinType::Inner); + } else if jt.contains(ast::JoinType::LEFT) + && jt.contains(ast::JoinType::RIGHT) + && jt.contains(ast::JoinType::OUTER) + { + // FULL OUTER JOIN (has LEFT, RIGHT, and OUTER) + JoinType::Full + } else if jt.contains(ast::JoinType::LEFT) && jt.contains(ast::JoinType::OUTER) { + JoinType::Left + } else if jt.contains(ast::JoinType::RIGHT) && jt.contains(ast::JoinType::OUTER) { + JoinType::Right + } else if jt.contains(ast::JoinType::OUTER) + && !jt.contains(ast::JoinType::LEFT) + && !jt.contains(ast::JoinType::RIGHT) + { + // Plain OUTER JOIN should also be FULL + JoinType::Full + } else if jt.contains(ast::JoinType::LEFT) { + JoinType::Left + } else if jt.contains(ast::JoinType::RIGHT) { + JoinType::Right + } else if jt.contains(ast::JoinType::CROSS) + || (jt.contains(ast::JoinType::INNER) && jt.contains(ast::JoinType::CROSS)) + { + JoinType::Cross + } else { + JoinType::Inner // Default to inner + } + } + ast::JoinOperator::TypedJoin(None) => JoinType::Inner, // Default JOIN is INNER JOIN + }; + + // Build join conditions + let (on_conditions, filter) = match constraint { + Some(ast::JoinConstraint::On(expr)) => { + // Parse ON clause into equijoin conditions and filters + self.parse_join_conditions(expr, left.schema(), right.schema())? + } + Some(ast::JoinConstraint::Using(columns)) => { + // Build equijoin conditions from USING clause + let on = self.build_using_conditions(columns, left.schema(), right.schema())?; + (on, None) + } + None => { + // Cross join or natural join + (Vec::new(), None) + } + }; + + // Build combined schema + let schema = self.build_join_schema(&left, &right, &join_type)?; + + Ok(LogicalPlan::Join(Join { + left: Arc::new(left), + right: Arc::new(right), + join_type, + on: on_conditions, + filter, + schema, + })) + } + + // Helper: Parse join conditions into equijoins and filters + fn parse_join_conditions( + &mut self, + expr: &ast::Expr, + left_schema: &SchemaRef, + right_schema: &SchemaRef, + ) -> Result { + // For now, we'll handle simple equality conditions + // More complex conditions will go into the filter + let mut equijoins = Vec::new(); + let mut filters = Vec::new(); + + // Try to extract equijoin conditions from the expression + self.extract_equijoin_conditions( + expr, + left_schema, + right_schema, + &mut equijoins, + &mut filters, + )?; + + let filter = if filters.is_empty() { + None + } else { + // Combine multiple filters with AND + Some( + filters + .into_iter() + .reduce(|acc, e| LogicalExpr::BinaryExpr { + left: Box::new(acc), + op: BinaryOperator::And, + right: Box::new(e), + }) + .unwrap(), + ) + }; + + Ok((equijoins, filter)) + } + + // Helper: Extract equijoin conditions from expression + fn extract_equijoin_conditions( + &mut self, + expr: &ast::Expr, + left_schema: &SchemaRef, + right_schema: &SchemaRef, + equijoins: &mut Vec<(LogicalExpr, LogicalExpr)>, + filters: &mut Vec, + ) -> Result<()> { + match expr { + ast::Expr::Binary(lhs, ast::Operator::Equals, rhs) => { + // Check if this is an equijoin condition (left.col = right.col) + let left_expr = self.build_expr(lhs, left_schema)?; + let right_expr = self.build_expr(rhs, right_schema)?; + + // For simplicity, we'll check if one references left and one references right + // In a real implementation, we'd need more sophisticated column resolution + equijoins.push((left_expr, right_expr)); + } + ast::Expr::Binary(lhs, ast::Operator::And, rhs) => { + // Recursively extract from AND conditions + self.extract_equijoin_conditions( + lhs, + left_schema, + right_schema, + equijoins, + filters, + )?; + self.extract_equijoin_conditions( + rhs, + left_schema, + right_schema, + equijoins, + filters, + )?; + } + _ => { + // Other conditions go into the filter + // We need a combined schema to build the expression + let combined_schema = self.combine_schemas(left_schema, right_schema)?; + let filter_expr = self.build_expr(expr, &combined_schema)?; + filters.push(filter_expr); + } + } + Ok(()) + } + + // Helper: Build equijoin conditions from USING clause + fn build_using_conditions( + &mut self, + columns: &[ast::Name], + left_schema: &SchemaRef, + right_schema: &SchemaRef, + ) -> Result> { + let mut conditions = Vec::new(); + + for col_name in columns { + let name = Self::name_to_string(col_name); + + // Find the column in both schemas + let _left_idx = left_schema + .columns + .iter() + .position(|(n, _)| n == &name) + .ok_or_else(|| { + LimboError::ParseError(format!("Column {name} not found in left table")) + })?; + let _right_idx = right_schema + .columns + .iter() + .position(|(n, _)| n == &name) + .ok_or_else(|| { + LimboError::ParseError(format!("Column {name} not found in right table")) + })?; + + conditions.push(( + LogicalExpr::Column(Column { + name: name.clone(), + table: None, // Will be resolved later + }), + LogicalExpr::Column(Column { + name, + table: None, // Will be resolved later + }), + )); + } + + Ok(conditions) + } + + // Helper: Build natural join by finding common columns + fn build_natural_join( + &mut self, + left: LogicalPlan, + right: LogicalPlan, + join_type: JoinType, + ) -> Result { + let left_schema = left.schema(); + let right_schema = right.schema(); + + // Find common column names + let mut common_columns = Vec::new(); + for (left_name, _) in &left_schema.columns { + if right_schema.columns.iter().any(|(n, _)| n == left_name) { + common_columns.push(ast::Name::Ident(left_name.clone())); + } + } + + if common_columns.is_empty() { + // Natural join with no common columns becomes a cross join + let schema = self.build_join_schema(&left, &right, &JoinType::Cross)?; + return Ok(LogicalPlan::Join(Join { + left: Arc::new(left), + right: Arc::new(right), + join_type: JoinType::Cross, + on: Vec::new(), + filter: None, + schema, + })); + } + + // Build equijoin conditions for common columns + let on = self.build_using_conditions(&common_columns, left_schema, right_schema)?; + let schema = self.build_join_schema(&left, &right, &join_type)?; + + Ok(LogicalPlan::Join(Join { + left: Arc::new(left), + right: Arc::new(right), + join_type, + on, + filter: None, + schema, + })) + } + + // Helper: Build schema for join result + fn build_join_schema( + &self, + left: &LogicalPlan, + right: &LogicalPlan, + _join_type: &JoinType, + ) -> Result { + let left_schema = left.schema(); + let right_schema = right.schema(); + + // For now, simply concatenate the schemas + // In a real implementation, we'd handle column name conflicts and nullable columns + let mut columns = left_schema.columns.clone(); + columns.extend(right_schema.columns.clone()); + + Ok(Arc::new(LogicalSchema::new(columns))) + } + + // Helper: Combine two schemas for expression building + fn combine_schemas(&self, left: &SchemaRef, right: &SchemaRef) -> Result { + let mut columns = left.columns.clone(); + columns.extend(right.columns.clone()); + Ok(Arc::new(LogicalSchema::new(columns))) } // Build projection @@ -1974,6 +2242,67 @@ mod tests { }; schema.add_btree_table(Arc::new(orders_table)); + // Create products table + let products_table = BTreeTable { + name: "products".to_string(), + root_page: 4, + primary_key_columns: vec![("id".to_string(), turso_parser::ast::SortOrder::Asc)], + columns: vec![ + SchemaColumn { + name: Some("id".to_string()), + ty: Type::Integer, + ty_str: "INTEGER".to_string(), + primary_key: true, + is_rowid_alias: true, + notnull: true, + default: None, + unique: false, + collation: None, + hidden: false, + }, + SchemaColumn { + name: Some("name".to_string()), + ty: Type::Text, + ty_str: "TEXT".to_string(), + primary_key: false, + is_rowid_alias: false, + notnull: false, + default: None, + unique: false, + collation: None, + hidden: false, + }, + SchemaColumn { + name: Some("price".to_string()), + ty: Type::Real, + ty_str: "REAL".to_string(), + primary_key: false, + is_rowid_alias: false, + notnull: false, + default: None, + unique: false, + collation: None, + hidden: false, + }, + SchemaColumn { + name: Some("product_id".to_string()), + ty: Type::Integer, + ty_str: "INTEGER".to_string(), + primary_key: false, + is_rowid_alias: false, + notnull: false, + default: None, + unique: false, + collation: None, + hidden: false, + }, + ], + has_rowid: true, + is_strict: false, + unique_sets: vec![], + }; + schema.add_btree_table(Arc::new(products_table)); + schema } @@ -3086,4 +3415,381 @@ mod tests { _ => panic!("Expected Projection as top-level operator, got: {plan:?}"), } } + + // ===== JOIN TESTS ===== + + #[test] + fn test_inner_join() { + let schema = create_test_schema(); + let sql = "SELECT * FROM users u INNER JOIN orders o ON u.id = o.user_id"; + let plan = parse_and_build(sql, &schema).unwrap(); + + match plan { + LogicalPlan::Projection(proj) => { + match &*proj.input { + LogicalPlan::Join(join) => { + assert_eq!(join.join_type, JoinType::Inner); + assert!(!join.on.is_empty(), "Should have join conditions"); + + // Check left input is users + match &*join.left { + LogicalPlan::TableScan(scan) => { + assert_eq!(scan.table_name, "users"); + } + _ => panic!("Expected TableScan for left input"), + } + + // Check right input is orders + match &*join.right { + LogicalPlan::TableScan(scan) => { + assert_eq!(scan.table_name, "orders"); + } + _ => panic!("Expected TableScan for right input"), + } + } + _ => panic!("Expected Join under Projection"), + } + } + _ => panic!("Expected Projection at top level"), + } + } + + #[test] + fn test_left_join() { + let schema = create_test_schema(); + let sql = "SELECT u.name, o.amount FROM users u LEFT JOIN orders o ON u.id = o.user_id"; + let plan = parse_and_build(sql, &schema).unwrap(); + + match plan { + LogicalPlan::Projection(proj) => { + assert_eq!(proj.exprs.len(), 2); // name and amount + match &*proj.input { + LogicalPlan::Join(join) => { + assert_eq!(join.join_type, JoinType::Left); + assert!(!join.on.is_empty(), "Should have join conditions"); + } + _ => panic!("Expected Join under Projection"), + } + } + _ => panic!("Expected Projection at top level"), + } + } + + #[test] + fn test_right_join() { + let schema = create_test_schema(); + let sql = "SELECT * FROM orders o RIGHT JOIN users u ON o.user_id = u.id"; + let plan = parse_and_build(sql, &schema).unwrap(); + + match plan { + LogicalPlan::Projection(proj) => match &*proj.input { + LogicalPlan::Join(join) => { + assert_eq!(join.join_type, JoinType::Right); + assert!(!join.on.is_empty(), "Should have join conditions"); + } + _ => panic!("Expected Join under Projection"), + }, + _ => panic!("Expected Projection at top level"), + } + } + + #[test] + fn test_full_outer_join() { + let schema = create_test_schema(); + let sql = "SELECT * FROM users u FULL OUTER JOIN orders o ON u.id = o.user_id"; + let plan = parse_and_build(sql, &schema).unwrap(); + + match plan { + LogicalPlan::Projection(proj) => match &*proj.input { + LogicalPlan::Join(join) => { + assert_eq!(join.join_type, JoinType::Full); + assert!(!join.on.is_empty(), "Should have join conditions"); + } + _ => panic!("Expected Join under Projection"), + }, + _ => panic!("Expected Projection at top level"), + } + } + + #[test] + fn test_cross_join() { + let schema = create_test_schema(); + let sql = "SELECT * FROM users CROSS JOIN orders"; + let plan = parse_and_build(sql, &schema).unwrap(); + + match plan { + LogicalPlan::Projection(proj) => match &*proj.input { + LogicalPlan::Join(join) => { + assert_eq!(join.join_type, JoinType::Cross); + assert!(join.on.is_empty(), "Cross join should have no conditions"); + assert!(join.filter.is_none(), "Cross join should have no filter"); + } + _ => panic!("Expected Join under Projection"), + }, + _ => panic!("Expected Projection at top level"), + } + } + + #[test] + fn test_join_with_multiple_conditions() { + let schema = create_test_schema(); + let sql = "SELECT * FROM users u JOIN orders o ON u.id = o.user_id AND u.age > 18"; + let plan = parse_and_build(sql, &schema).unwrap(); + + match plan { + LogicalPlan::Projection(proj) => { + match &*proj.input { + LogicalPlan::Join(join) => { + assert_eq!(join.join_type, JoinType::Inner); + // Should have at least one equijoin condition + assert!(!join.on.is_empty(), "Should have join conditions"); + // Additional conditions may be in filter + // The exact distribution depends on our implementation + } + _ => panic!("Expected Join under Projection"), + } + } + _ => panic!("Expected Projection at top level"), + } + } + + #[test] + fn test_join_using_clause() { + let schema = create_test_schema(); + // Note: Both tables should have an 'id' column for this to work + let sql = "SELECT * FROM users JOIN orders USING (id)"; + let plan = parse_and_build(sql, &schema).unwrap(); + + match plan { + LogicalPlan::Projection(proj) => match &*proj.input { + LogicalPlan::Join(join) => { + assert_eq!(join.join_type, JoinType::Inner); + assert!( + !join.on.is_empty(), + "USING clause should create join conditions" + ); + } + _ => panic!("Expected Join under Projection"), + }, + _ => panic!("Expected Projection at top level"), + } + } + + #[test] + fn test_natural_join() { + let schema = create_test_schema(); + let sql = "SELECT * FROM users NATURAL JOIN orders"; + let plan = parse_and_build(sql, &schema).unwrap(); + + match plan { + LogicalPlan::Projection(proj) => { + match &*proj.input { + LogicalPlan::Join(join) => { + // Natural join finds common columns (id in this case) + // If no common columns, it becomes a cross join + assert!( + !join.on.is_empty() || join.join_type == JoinType::Cross, + "Natural join should either find common columns or become cross join" + ); + } + _ => panic!("Expected Join under Projection"), + } + } + _ => panic!("Expected Projection at top level"), + } + } + + #[test] + fn test_three_way_join() { + let schema = create_test_schema(); + let sql = "SELECT * FROM users u + JOIN orders o ON u.id = o.user_id + JOIN products p ON o.product_id = p.id"; + let plan = parse_and_build(sql, &schema).unwrap(); + + match plan { + LogicalPlan::Projection(proj) => { + match &*proj.input { + LogicalPlan::Join(join2) => { + // Second join (with products) + assert_eq!(join2.join_type, JoinType::Inner); + match &*join2.left { + LogicalPlan::Join(join1) => { + // First join (users with orders) + assert_eq!(join1.join_type, JoinType::Inner); + } + _ => panic!("Expected nested Join for three-way join"), + } + } + _ => panic!("Expected Join under Projection"), + } + } + _ => panic!("Expected Projection at top level"), + } + } + + #[test] + fn test_mixed_join_types() { + let schema = create_test_schema(); + let sql = "SELECT * FROM users u + LEFT JOIN orders o ON u.id = o.user_id + INNER JOIN products p ON o.product_id = p.id"; + let plan = parse_and_build(sql, &schema).unwrap(); + + match plan { + LogicalPlan::Projection(proj) => { + match &*proj.input { + LogicalPlan::Join(join2) => { + // Second join should be INNER + assert_eq!(join2.join_type, JoinType::Inner); + match &*join2.left { + LogicalPlan::Join(join1) => { + // First join should be LEFT + assert_eq!(join1.join_type, JoinType::Left); + } + _ => panic!("Expected nested Join"), + } + } + _ => panic!("Expected Join under Projection"), + } + } + _ => panic!("Expected Projection at top level"), + } + } + + #[test] + fn test_join_with_filter() { + let schema = create_test_schema(); + let sql = "SELECT * FROM users u JOIN orders o ON u.id = o.user_id WHERE o.amount > 100"; + let plan = parse_and_build(sql, &schema).unwrap(); + + match plan { + LogicalPlan::Projection(proj) => { + match &*proj.input { + LogicalPlan::Filter(filter) => { + // WHERE clause creates a Filter above the Join + match &*filter.input { + LogicalPlan::Join(join) => { + assert_eq!(join.join_type, JoinType::Inner); + } + _ => panic!("Expected Join under Filter"), + } + } + _ => panic!("Expected Filter under Projection"), + } + } + _ => panic!("Expected Projection at top level"), + } + } + + #[test] + fn test_join_with_projection() { + let schema = create_test_schema(); + let sql = "SELECT u.name, o.amount FROM users u JOIN orders o ON u.id = o.user_id"; + let plan = parse_and_build(sql, &schema).unwrap(); + + match plan { + LogicalPlan::Projection(proj) => { + assert_eq!(proj.exprs.len(), 2); // u.name and o.amount + match &*proj.input { + LogicalPlan::Join(join) => { + assert_eq!(join.join_type, JoinType::Inner); + } + _ => panic!("Expected Join under Projection"), + } + } + _ => panic!("Expected Projection at top level"), + } + } + + #[test] + fn test_join_with_aggregation() { + let schema = create_test_schema(); + let sql = "SELECT u.name, SUM(o.amount) + FROM users u JOIN orders o ON u.id = o.user_id + GROUP BY u.name"; + let plan = parse_and_build(sql, &schema).unwrap(); + + match plan { + LogicalPlan::Aggregate(agg) => { + assert_eq!(agg.group_expr.len(), 1); // GROUP BY u.name + assert_eq!(agg.aggr_expr.len(), 1); // SUM(o.amount) + match &*agg.input { + LogicalPlan::Join(join) => { + assert_eq!(join.join_type, JoinType::Inner); + } + _ => panic!("Expected Join under Aggregate"), + } + } + _ => panic!("Expected Aggregate"), + } + } + + #[test] + fn test_join_with_order_by() { + let schema = create_test_schema(); + let sql = "SELECT * FROM users u JOIN orders o ON u.id = o.user_id ORDER BY o.amount DESC"; + let plan = parse_and_build(sql, &schema).unwrap(); + + match plan { + LogicalPlan::Sort(sort) => { + assert_eq!(sort.exprs.len(), 1); + assert!(!sort.exprs[0].asc); // DESC + match &*sort.input { + LogicalPlan::Projection(proj) => match &*proj.input { + LogicalPlan::Join(join) => { + assert_eq!(join.join_type, JoinType::Inner); + } + _ => panic!("Expected Join under Projection"), + }, + _ => panic!("Expected Projection under Sort"), + } + } + _ => panic!("Expected Sort at top level"), + } + } + + #[test] + fn test_join_in_subquery() { + let schema = create_test_schema(); + let sql = "SELECT * FROM ( + SELECT u.id, u.name, o.amount + FROM users u JOIN orders o ON u.id = o.user_id + ) WHERE amount > 100"; + let plan = parse_and_build(sql, &schema).unwrap(); + + match plan { + LogicalPlan::Projection(outer_proj) => match &*outer_proj.input { + LogicalPlan::Filter(filter) => match &*filter.input { + LogicalPlan::Projection(inner_proj) => match &*inner_proj.input { + LogicalPlan::Join(join) => { + assert_eq!(join.join_type, JoinType::Inner); + } + _ => panic!("Expected Join in subquery"), + }, + _ => panic!("Expected Projection for subquery"), + }, + _ => panic!("Expected Filter"), + }, + _ => panic!("Expected Projection at top level"), + } + } + + #[test] + fn test_join_ambiguous_column() { + let schema = create_test_schema(); + // Both users and orders have an 'id' column + let sql = "SELECT id FROM users JOIN orders ON users.id = orders.user_id"; + let result = parse_and_build(sql, &schema); + // This might error or succeed depending on how we handle ambiguous columns + // For now, just check that parsing completes + match result { + Ok(_) => { + // If successful, the implementation handles ambiguous columns somehow + } + Err(_) => { + // If error, the implementation rejects ambiguous columns + } + } + } }