From ea793e412625ab9c31f66715b557919a09a9afe1 Mon Sep 17 00:00:00 2001 From: jussisaurio Date: Sun, 14 Jul 2024 01:04:08 +0300 Subject: [PATCH] Inner join, table aliases, qualified column names --- core/schema.rs | 5 +- core/translate.rs | 381 +++++++++++++++++++++++++++++++++++----------- core/vdbe.rs | 56 +++++-- testing/all.test | 39 +++++ 4 files changed, 383 insertions(+), 98 deletions(-) diff --git a/core/schema.rs b/core/schema.rs index e8a38e49f..7f41dffc0 100644 --- a/core/schema.rs +++ b/core/schema.rs @@ -33,7 +33,7 @@ impl Schema { } } -#[derive(Clone)] +#[derive(Clone, Debug)] pub enum Table { BTree(Rc), Pseudo(Rc), @@ -96,6 +96,7 @@ impl PartialEq for Table { } } +#[derive(Debug)] pub struct BTreeTable { pub root_page: usize, pub name: String, @@ -150,6 +151,7 @@ impl BTreeTable { } } +#[derive(Debug)] pub struct PseudoTable { pub columns: Vec, } @@ -295,6 +297,7 @@ pub fn _build_pseudo_table(columns: &[ResultColumn]) -> PseudoTable { table } +#[derive(Debug)] pub struct Column { pub name: String, pub ty: Type, diff --git a/core/translate.rs b/core/translate.rs index 9fdf595b1..74ac0f0dd 100644 --- a/core/translate.rs +++ b/core/translate.rs @@ -37,7 +37,8 @@ struct LoopInfo { struct SrcTable { table: Table, - _join_info: Option, // FIXME: preferably this should be a reference with lifetime == Select ast expr + alias: Option, + join_info: Option, // FIXME: preferably this should be a reference with lifetime == Select ast expr } struct ColumnInfo { @@ -91,14 +92,21 @@ fn build_select(schema: &Schema, select: ast::Select) -> Result { let mut joins = Vec::new(); joins.push(SrcTable { table: Table::BTree(table.clone()), - _join_info: None, + alias: maybe_alias, + join_info: None, }); if let Some(selected_joins) = from.joins { for join in selected_joins { - let table_name = match &join.table { - ast::SelectTable::Table(name, ..) => name.name.clone(), + let (table_name, maybe_alias) = match &join.table { + ast::SelectTable::Table(name, alias, ..) => ( + name.name.clone(), + alias.clone().map(|als| match als { + ast::As::As(alias) => alias, // users as u + ast::As::Elided(alias) => alias, // users u + }), + ), _ => todo!(), }; let table_name = &table_name.0; + let maybe_alias = maybe_alias.map(|als| als.0); let table = match schema.get_table(table_name) { Some(table) => table, None => anyhow::bail!("Parse error: no such table: {}", table_name), }; joins.push(SrcTable { table: Table::BTree(table), - _join_info: Some(join.clone()), + alias: maybe_alias, + join_info: Some(join.clone()), }); } } @@ -205,9 +222,7 @@ fn translate_select(mut select: Select) -> Result { }; if !select.src_tables.is_empty() { - translate_tables_begin(&mut program, &mut select); - - let where_maybe = insert_where_clause_instructions(&select, &mut program)?; + let condition_label_maybe = translate_tables_begin(&mut program, &mut select)?; let (register_start, register_end) = translate_columns(&mut program, &select)?; @@ -219,8 +234,8 @@ fn translate_select(mut select: Select) -> Result { emit_limit_insn(&limit_info, &mut program); } - if let Some(where_clause_label) = where_maybe { - program.resolve_label(where_clause_label, program.offset()); + if let Some(condition_label) = condition_label_maybe { + program.resolve_label(condition_label, program.offset()); } translate_tables_end(&mut program, &select); @@ -245,7 +260,7 @@ fn translate_select(mut select: Select) -> Result { } } else { assert!(!select.exist_aggregation); - let where_maybe = insert_where_clause_instructions(&select, &mut program)?; + let where_maybe = translate_where(&select, &mut program)?; let (register_start, register_end) = translate_columns(&mut program, &select)?; if let Some(where_clause_label) = where_maybe { program.resolve_label(where_clause_label, program.offset() + 1); @@ -288,10 +303,7 @@ fn emit_limit_insn(limit_info: &Option, program: &mut ProgramBuilder) } } -fn insert_where_clause_instructions( - select: &Select, - program: &mut ProgramBuilder, -) -> Result> { +fn translate_where(select: &Select, program: &mut ProgramBuilder) -> Result> { if let Some(w) = &select.where_clause { let label = program.allocate_label(); translate_condition_expr(program, select, w, label)?; @@ -301,16 +313,76 @@ fn insert_where_clause_instructions( } } -fn translate_tables_begin(program: &mut ProgramBuilder, select: &mut Select) { +fn translate_conditions( + program: &mut ProgramBuilder, + select: &Select, +) -> Result> { + // FIXME: clone() + // TODO: only supports INNER JOIN on a single condition atm, e.g. SELECT * FROM a JOIN b ON a.id = b.id, no AND/OR + let join_constraints = select + .src_tables + .iter() + .map(|v| v.join_info.clone()) + .filter_map(|v| v.map(|v| v.constraint)) + .flatten() + .collect::>(); + // TODO: only supports one JOIN; -> add support for multiple JOINs, e.g. SELECT * FROM a JOIN b ON a.id = b.id JOIN c ON b.id = c.id + if join_constraints.len() > 1 { + anyhow::bail!("Parse error: multiple JOINs not supported"); + } + + let maybe_join = join_constraints.first(); + + match (&select.where_clause, maybe_join) { + (Some(where_clause), Some(join)) => { + match join { + ast::JoinConstraint::On(expr) => { + // Combine where clause and join condition + let label = program.allocate_label(); + translate_condition_expr(program, select, where_clause, label)?; + translate_condition_expr(program, select, expr, label)?; + Ok(Some(label)) + } + ast::JoinConstraint::Using(_) => { + todo!(); + } + } + } + (None, None) => { + Ok(None) + } + (Some(where_clause), None) => { + let label = program.allocate_label(); + translate_condition_expr(program, select, where_clause, label)?; + Ok(Some(label)) + } + (None, Some(join)) => match join { + ast::JoinConstraint::On(expr) => { + let label = program.allocate_label(); + translate_condition_expr(program, select, expr, label)?; + Ok(Some(label)) + } + ast::JoinConstraint::Using(_) => { + todo!(); + } + }, + } +} + +fn translate_tables_begin( + program: &mut ProgramBuilder, + select: &mut Select, +) -> Result> { for join in &select.src_tables { - let table = &join.table; - let loop_info = translate_table_open_cursor(program, table); + let loop_info = translate_table_open_cursor(program, join); select.loops.push(loop_info); } for loop_info in &mut select.loops { translate_table_open_loop(program, loop_info); } + + translate_conditions(program, select) } fn translate_tables_end(program: &mut ProgramBuilder, select: &Select) { @@ -326,9 +398,13 @@ fn translate_tables_end(program: &mut ProgramBuilder, select: &Select) { } } -fn translate_table_open_cursor(program: &mut ProgramBuilder, table: &Table) -> LoopInfo { - let cursor_id = program.alloc_cursor_id(table.clone()); - let root_page = match table { +fn translate_table_open_cursor(program: &mut ProgramBuilder, table: &SrcTable) -> LoopInfo { + let table_identifier = match &table.alias { + Some(alias) => alias.clone(), + None => table.table.get_name().to_string(), + }; + let cursor_id = program.alloc_cursor_id(table_identifier, table.table.clone()); + let root_page = match &table.table { Table::BTree(btree) => btree.root_page, Table::Pseudo(_) => todo!(), }; @@ -398,9 +474,8 @@ fn translate_column( ast::ResultColumn::Star => { let mut target_register = target_register; for join in &select.src_tables { - let table = &join.table; - translate_table_star(table, program, target_register); - target_register += table.columns().len(); + translate_table_star(join, program, target_register); + target_register += &join.table.columns().len(); } } ast::ResultColumn::TableStar(_) => todo!(), @@ -408,8 +483,13 @@ fn translate_column( Ok(()) } -fn translate_table_star(table: &Table, program: &mut ProgramBuilder, target_register: usize) { - let table_cursor = program.resolve_cursor_id(table); +fn translate_table_star(table: &SrcTable, program: &mut ProgramBuilder, target_register: usize) { + let table_identifier = match &table.alias { + Some(alias) => alias.clone(), + None => table.table.get_name().to_string(), + }; + let table_cursor = program.resolve_cursor_id(&table_identifier); + let table = &table.table; for (i, col) in table.columns().iter().enumerate() { let col_target_register = target_register + i; if table.column_is_rowid_alias(col) { @@ -541,7 +621,7 @@ fn translate_condition_expr( rhs: e2_reg, target_pc: jump_target, }, - _ => todo!(), + other => todo!("{:?}", other), }); Ok(()) } @@ -572,6 +652,24 @@ fn translate_condition_expr( } } +fn wrap_eval_jump_expr( + program: &mut ProgramBuilder, + insn: Insn, + target_register: usize, + if_true_label: BranchOffset, +) { + program.emit_insn(Insn::Integer { + value: 1, // emit True by default + dest: target_register, + }); + program.emit_insn_with_label_dependency(insn, if_true_label); + program.emit_insn(Insn::Integer { + value: 0, // emit False if we reach this point (no jump) + dest: target_register, + }); + program.preassign_label_to_next_insn(if_true_label); +} + fn translate_expr( program: &mut ProgramBuilder, select: &Select, @@ -585,50 +683,95 @@ fn translate_expr( let e2_reg = program.alloc_register(); let _ = translate_expr(program, select, e1, e1_reg)?; let _ = translate_expr(program, select, e2, e2_reg)?; - program.emit_insn(match op { - ast::Operator::NotEquals => Insn::Ne { - lhs: e1_reg, - rhs: e2_reg, - target_pc: program.offset() + 3, // jump to "emit True" instruction - }, - ast::Operator::Equals => Insn::Eq { - lhs: e1_reg, - rhs: e2_reg, - target_pc: program.offset() + 3, - }, - ast::Operator::Less => Insn::Lt { - lhs: e1_reg, - rhs: e2_reg, - target_pc: program.offset() + 3, - }, - ast::Operator::LessEquals => Insn::Le { - lhs: e1_reg, - rhs: e2_reg, - target_pc: program.offset() + 3, - }, - ast::Operator::Greater => Insn::Gt { - lhs: e1_reg, - rhs: e2_reg, - target_pc: program.offset() + 3, - }, - ast::Operator::GreaterEquals => Insn::Ge { - lhs: e1_reg, - rhs: e2_reg, - target_pc: program.offset() + 3, - }, - _ => todo!(), - }); - program.emit_insn(Insn::Integer { - value: 0, // emit False - dest: target_register, - }); - program.emit_insn(Insn::Goto { - target_pc: program.offset() + 2, - }); - program.emit_insn(Insn::Integer { - value: 1, // emit True - dest: target_register, - }); + + match op { + ast::Operator::NotEquals => { + let if_true_label = program.allocate_label(); + wrap_eval_jump_expr( + program, + Insn::Ne { + lhs: e1_reg, + rhs: e2_reg, + target_pc: if_true_label, + }, + target_register, + if_true_label, + ); + } + ast::Operator::Equals => { + let if_true_label = program.allocate_label(); + wrap_eval_jump_expr( + program, + Insn::Eq { + lhs: e1_reg, + rhs: e2_reg, + target_pc: if_true_label, + }, + target_register, + if_true_label, + ); + } + ast::Operator::Less => { + let if_true_label = program.allocate_label(); + wrap_eval_jump_expr( + program, + Insn::Lt { + lhs: e1_reg, + rhs: e2_reg, + target_pc: if_true_label, + }, + target_register, + if_true_label, + ); + } + ast::Operator::LessEquals => { + let if_true_label = program.allocate_label(); + wrap_eval_jump_expr( + program, + Insn::Le { + lhs: e1_reg, + rhs: e2_reg, + target_pc: if_true_label, + }, + target_register, + if_true_label, + ); + } + ast::Operator::Greater => { + let if_true_label = program.allocate_label(); + wrap_eval_jump_expr( + program, + Insn::Gt { + lhs: e1_reg, + rhs: e2_reg, + target_pc: if_true_label, + }, + target_register, + if_true_label, + ); + } + ast::Operator::GreaterEquals => { + let if_true_label = program.allocate_label(); + wrap_eval_jump_expr( + program, + Insn::Ge { + lhs: e1_reg, + rhs: e2_reg, + target_pc: if_true_label, + }, + target_register, + if_true_label, + ); + } + ast::Operator::Add => { + program.emit_insn(Insn::Add { + lhs: e1_reg, + rhs: e2_reg, + dest: target_register, + }); + } + other_unimplemented => todo!("{:?}", other_unimplemented), + } Ok(target_register) } ast::Expr::Case { .. } => todo!(), @@ -753,7 +896,23 @@ fn translate_expr( ast::Expr::Name(_) => todo!(), ast::Expr::NotNull(_) => todo!(), ast::Expr::Parenthesized(_) => todo!(), - ast::Expr::Qualified(_, _) => todo!(), + ast::Expr::Qualified(tbl, ident) => { + let (idx, col, cursor_id) = resolve_ident_qualified(program, &tbl.0, &ident.0, select)?; + if col.primary_key { + program.emit_insn(Insn::RowId { + cursor_id, + dest: target_register, + }); + } else { + program.emit_insn(Insn::Column { + column: idx, + dest: target_register, + cursor_id, + }); + } + maybe_apply_affinity(col, target_register, program); + Ok(target_register) + } ast::Expr::Raise(_, _) => todo!(), ast::Expr::Subquery(_) => todo!(), ast::Expr::Unary(_, _) => todo!(), @@ -761,25 +920,77 @@ fn translate_expr( } } +fn resolve_ident_qualified<'a>( + program: &ProgramBuilder, + table_name: &String, + ident: &String, + select: &'a Select, +) -> Result<(usize, &'a Column, usize)> { + for join in &select.src_tables { + match join.table { + Table::BTree(ref table) => { + let table_identifier = match &join.alias { + Some(alias) => alias.clone(), + None => table.name.to_string(), + }; + if table_identifier == *table_name { + let res = table + .columns + .iter() + .enumerate() + .find(|(_, col)| col.name == *ident); + if res.is_some() { + let (idx, col) = res.unwrap(); + let cursor_id = program.resolve_cursor_id(&table_identifier); + return Ok((idx, col, cursor_id)); + } + } + } + Table::Pseudo(_) => todo!(), + } + } + anyhow::bail!( + "Parse error: column with qualified name {}.{} not found", + table_name, + ident + ); +} + fn resolve_ident_table<'a>( program: &ProgramBuilder, ident: &String, select: &'a Select, ) -> Result<(usize, &'a Column, usize)> { + let mut found = Vec::new(); for join in &select.src_tables { - let res = join - .table - .columns() - .iter() - .enumerate() - .find(|(_, col)| col.name == *ident); - if res.is_some() { - let (idx, col) = res.unwrap(); - let cursor_id = program.resolve_cursor_id(&join.table); - return Ok((idx, col, cursor_id)); + match join.table { + Table::BTree(ref table) => { + let table_identifier = match &join.alias { + Some(alias) => alias.clone(), + None => table.name.to_string(), + }; + let res = table + .columns + .iter() + .enumerate() + .find(|(_, col)| col.name == *ident); + if res.is_some() { + let (idx, col) = res.unwrap(); + let cursor_id = program.resolve_cursor_id(&table_identifier); + found.push((idx, col, cursor_id)); + } + } + Table::Pseudo(_) => todo!(), } } - anyhow::bail!("Parse error: column with name {} not found", ident.as_str()); + if found.len() == 1 { + return Ok(found[0]); + } + if found.is_empty() { + anyhow::bail!("Parse error: column with name {} not found", ident.as_str()); + } + + anyhow::bail!("Parse error: ambiguous column name {}", ident.as_str()); } fn translate_aggregation( diff --git a/core/vdbe.rs b/core/vdbe.rs index e9e4df755..a00018301 100644 --- a/core/vdbe.rs +++ b/core/vdbe.rs @@ -26,6 +26,12 @@ pub enum Insn { Null { dest: usize, }, + // Add two registers and store the result in a third register. + Add { + lhs: usize, + rhs: usize, + dest: usize, + }, // If the given register is not NULL, jump to the given PC. NotNull { reg: usize, @@ -232,7 +238,7 @@ pub struct ProgramBuilder { unresolved_labels: Vec>, next_insn_label: Option, // Cursors that are referenced by the program. Indexed by CursorID. - cursor_ref: Vec, + cursor_ref: Vec<(String, Table)>, } impl ProgramBuilder { @@ -265,13 +271,10 @@ impl ProgramBuilder { self.next_free_register } - pub fn alloc_cursor_id(&mut self, table: Table) -> usize { + pub fn alloc_cursor_id(&mut self, table_identifier: String, table: Table) -> usize { let cursor = self.next_free_cursor_id; self.next_free_cursor_id += 1; - if self.cursor_ref.iter().any(|t| *t == table) { - todo!("duplicate table is unhandled. see how resolve_ident_table() calls resolve_cursor_id") - } - self.cursor_ref.push(table); + self.cursor_ref.push((table_identifier, table)); assert!(self.cursor_ref.len() == self.next_free_cursor_id); cursor } @@ -293,7 +296,7 @@ impl ProgramBuilder { } pub fn emit_constant_insns(&mut self) { - self.insns.extend(self.constant_insns.drain(..)); + self.insns.append(&mut self.constant_insns); } pub fn emit_insn_with_label_dependency(&mut self, insn: Insn, label: BranchOffset) { @@ -449,8 +452,11 @@ impl ProgramBuilder { } // translate table to cursor id - pub fn resolve_cursor_id(&self, table: &Table) -> CursorID { - self.cursor_ref.iter().position(|t| t == table).unwrap() + pub fn resolve_cursor_id(&self, table_identifier: &str) -> CursorID { + self.cursor_ref + .iter() + .position(|(t_ident, _)| *t_ident == table_identifier) + .unwrap() } pub fn build(self) -> Program { @@ -503,7 +509,7 @@ impl ProgramState { pub struct Program { pub max_registers: usize, pub insns: Vec, - pub cursor_ref: Vec
, + pub cursor_ref: Vec<(String, Table)>, } impl Program { @@ -539,6 +545,23 @@ impl Program { assert!(*target_pc >= 0); state.pc = *target_pc; } + Insn::Add { lhs, rhs, dest } => { + let lhs = *lhs; + let rhs = *rhs; + let dest = *dest; + match (&state.registers[lhs], &state.registers[rhs]) { + (OwnedValue::Integer(lhs), OwnedValue::Integer(rhs)) => { + state.registers[dest] = OwnedValue::Integer(lhs + rhs); + } + (OwnedValue::Float(lhs), OwnedValue::Float(rhs)) => { + state.registers[dest] = OwnedValue::Float(lhs + rhs); + } + _ => { + todo!(); + } + } + state.pc += 1; + } Insn::Null { dest } => { state.registers[*dest] = OwnedValue::Null; state.pc += 1; @@ -1212,6 +1235,15 @@ fn insn_to_str(program: &Program, addr: InsnReference, insn: &Insn, indent: Stri 0, format!("Start at {}", target_pc), ), + Insn::Add { lhs, rhs, dest } => ( + "Add", + *lhs as i32, + *rhs as i32, + *dest as i32, + OwnedValue::Text(Rc::new("".to_string())), + 0, + format!("r[{}]=r[{}]+r[{}]", dest, lhs, rhs), + ), Insn::Null { dest } => ( "Null", *dest as i32, @@ -1377,7 +1409,7 @@ fn insn_to_str(program: &Program, addr: InsnReference, insn: &Insn, indent: Stri column, dest, } => { - let table = &program.cursor_ref[*cursor_id]; + let (_, table) = &program.cursor_ref[*cursor_id]; ( "Column", *cursor_id as i32, @@ -1513,7 +1545,7 @@ fn insn_to_str(program: &Program, addr: InsnReference, insn: &Insn, indent: Stri format!( "r[{}]={}.rowid", dest, - &program.cursor_ref[*cursor_id].get_name() + &program.cursor_ref[*cursor_id].1.get_name() ), ), Insn::DecrJumpZero { reg, target_pc } => ( diff --git a/testing/all.test b/testing/all.test index c22335572..75f28a809 100755 --- a/testing/all.test +++ b/testing/all.test @@ -174,3 +174,42 @@ do_execsql_test coalesce-from-table-column { do_execsql_test coalesce-from-table-multiple-columns { select coalesce(NULL, age), coalesce(NULL, id) from users where age = 94 limit 1; } {94|1} + +do_execsql_test inner-join-pk { + select users.first_name as user_name, products.name as product_name from users join products on users.id = products.id; +} {Jamie|hat +Cindy|cap +Tommy|shirt +Jennifer|sweater +Edward|sweatshirt +Nicholas|shorts +Aimee|jeans +Rachel|sneakers +Matthew|boots +Daniel|coat +Travis|accessories} + +do_execsql_test inner-join-non-pk-unqualified { + select first_name, name from users join products on first_name != name limit 1; +} {Jamie|hat} + +do_execsql_test inner-join-non-pk-qualified { + select users.first_name as user_name, products.name as product_name from users join products on users.first_name = products.name; +} {} + +do_execsql_test inner-join-self { + select u1.first_name as user_name, u2.first_name as neighbor_name from users u1 join users as u2 on u1.id = u2.id + 1 limit 1; +} {Cindy|Jamie} + +do_execsql_test inner-join-self-with-where { + select u1.first_name as user_name, u2.first_name as neighbor_name from users u1 join users as u2 on u1.id = u2.id + 1 where u1.id = 5 limit 1; +} {Edward|Jennifer} + +do_execsql_test inner-join-with-where-2 { + select u.first_name from users u join products as p on u.first_name != p.name where u.last_name = 'Williams' limit 1; +} {Laura} + +do_execsql_test select-add { + select u.age + 1 from users u where u.age = 91 limit 1; +} {92} +