diff --git a/core/schema.rs b/core/schema.rs index cc87b4892..302b8e64a 100644 --- a/core/schema.rs +++ b/core/schema.rs @@ -33,6 +33,7 @@ impl Schema { } } +#[derive(Clone)] pub enum Table { BTree(Rc), Pseudo(Rc), @@ -65,6 +66,27 @@ impl Table { } } +impl PartialEq for Table { + fn eq(&self, other: &Self) -> bool { + match (self, other) { + (Table::BTree(a), Table::BTree(b)) => Rc::ptr_eq(a, b), + (Table::Pseudo(a), Table::Pseudo(b)) => Rc::ptr_eq(a, b), + _ => false, + } + } +} + +impl Eq for Table {} + +impl std::hash::Hash for Table { + fn hash(&self, state: &mut H) { + match self { + Table::BTree(table) => std::ptr::hash(table.as_ref(), state), + Table::Pseudo(table) => std::ptr::hash(table.as_ref(), state), + } + } +} + pub struct BTreeTable { pub root_page: usize, pub name: String, diff --git a/core/translate.rs b/core/translate.rs index 0c2b16d3e..b14282ebe 100644 --- a/core/translate.rs +++ b/core/translate.rs @@ -3,21 +3,46 @@ use std::rc::Rc; use crate::function::AggFunc; use crate::pager::Pager; -use crate::schema::{Schema, Table}; +use crate::schema::{Column, Schema, Table}; use crate::sqlite3_ondisk::{DatabaseHeader, MIN_PAGE_CACHE_SIZE}; use crate::util::normalize_ident; use crate::vdbe::{Insn, Program, ProgramBuilder}; use anyhow::Result; -use sqlite3_parser::ast; +use sqlite3_parser::ast::{self, Expr}; struct Select { columns: Vec, column_info: Vec, from: Option, + joins: Option>, limit: Option, exist_aggregation: bool, } +struct LoopInfo { + table: Table, + rewind_offset: usize, + open_cursor: usize, +} + +struct SelectContext { + /// Ordered list of opened read table loops + /// Used for generating a loop that looks like this: + /// cursor 0 = open table 0 + /// for each row in cursor 0 + /// cursor 1 = open table 1 + /// for each row in cursor 1 + /// ... 
+ /// end cursor 1 + /// end cursor 0 + loops: Vec, +} + +struct Join { + table: Table, + info: ast::JoinedSelectTable, // FIXME: preferably this should be a reference with lifetime == Select ast expr +} + struct ColumnInfo { func: Option, args: Option>, @@ -74,13 +99,37 @@ fn build_select(schema: &Schema, select: ast::Select) -> Result { from: None, .. } => { - let column_info = analyze_columns(&columns, None); + let column_info = analyze_columns(&columns, None, &None); let exist_aggregation = column_info.iter().any(|info| info.func.is_some()); Ok(Select { columns, column_info, from: None, + joins: None, limit: select.limit.clone(), exist_aggregation, }) @@ -114,15 +164,14 @@ fn translate_select(select: Select) -> Result { let target_register = program.alloc_register(); Some(translate_expr( &mut program, - None, - None, + &select, + &SelectContext { loops: Vec::new() }, &limit.expr, target_register, - )) + )?) } else { None }; - let cursor_id = program.alloc_cursor_id(); let parsed_limit = select.limit.as_ref().and_then(|limit| { if let ast::Expr::Literal(ast::Literal::Numeric(num)) = &limit.expr { num.parse::().ok() @@ -130,28 +179,27 @@ fn translate_select(select: Select) -> Result { None } }); - let limit_insn = match (parsed_limit, &select.from) { + let from = select.from.as_ref(); + let limit_insn = match (parsed_limit, from) { (Some(0), _) => Some(program.emit_placeholder()), - (_, Some(table)) => { - let root_page = match table { - Table::BTree(table) => table.root_page, - Table::Pseudo(_) => todo!(), - }; - program.emit_insn(Insn::OpenReadAsync { - cursor_id, - root_page, - }); - program.emit_insn(Insn::OpenReadAwait); - program.emit_insn(Insn::RewindAsync { cursor_id }); - let rewind_await_offset = program.emit_placeholder(); + (_, Some(_)) => { + let select_context = translate_tables_begin(&mut program, &select); + let (register_start, register_end) = - translate_columns(&mut program, Some(cursor_id), &select); - let limit_insn = if 
select.exist_aggregation { - program.emit_insn(Insn::NextAsync { cursor_id }); - program.emit_insn(Insn::NextAwait { - cursor_id, - pc_if_next: rewind_await_offset, + translate_columns(&mut program, &select, &select_context)?; + + let mut limit_insn: Option = None; + if !select.exist_aggregation { + program.emit_insn(Insn::ResultRow { + start_reg: register_start, + count: register_end - register_start, }); + limit_insn = limit_reg.map(|_| program.emit_placeholder()); + } + + translate_tables_end(&mut program, &select, &select_context); + + if select.exist_aggregation { let mut target = register_start; for info in &select.column_info { if let Some(func) = &info.func { @@ -167,32 +215,14 @@ fn translate_select(select: Select) -> Result { start_reg: register_start, count: register_end - register_start, }); - limit_reg.map(|_| program.emit_placeholder()) - } else { - program.emit_insn(Insn::ResultRow { - start_reg: register_start, - count: register_end - register_start, - }); - let limit_insn = limit_reg.map(|_| program.emit_placeholder()); - program.emit_insn(Insn::NextAsync { cursor_id }); - program.emit_insn(Insn::NextAwait { - cursor_id, - pc_if_next: rewind_await_offset, - }); - limit_insn - }; - program.fixup_insn( - rewind_await_offset, - Insn::RewindAwait { - cursor_id, - pc_if_empty: program.offset(), - }, - ); + limit_insn = limit_reg.map(|_| program.emit_placeholder()); + } limit_insn } (_, None) => { assert!(!select.exist_aggregation); - let (register_start, register_end) = translate_columns(&mut program, None, &select); + let (register_start, register_end) = + translate_columns(&mut program, &select, &SelectContext { loops: Vec::new() })?; program.emit_insn(Insn::ResultRow { start_reg: register_start, count: register_end - register_start, @@ -227,11 +257,94 @@ fn translate_select(select: Select) -> Result { Ok(program.build()) } +fn translate_tables_begin(program: &mut ProgramBuilder, select: &Select) -> SelectContext { + let mut context = SelectContext { 
loops: Vec::new() }; + + translate_table_open_cursor(program, &mut context, select.from.as_ref().unwrap()); + + if select.joins.is_some() { + for join in select.joins.as_ref().unwrap() { + let table = &join.table; + translate_table_open_cursor(program, &mut context, table); + } + } + + let mut loop_index = 0; + translate_table_open_loop(program, &mut context, loop_index); + + loop_index += 1; + if select.joins.is_some() { + for _ in select.joins.as_ref().unwrap() { + translate_table_open_loop(program, &mut context, loop_index); + loop_index += 1; + } + } + context +} + +fn translate_tables_end( + program: &mut ProgramBuilder, + select: &Select, + select_context: &SelectContext, +) { + // iterate in reverse order as we open cursors in order + for table_loop in select_context.loops.iter().rev() { + let cursor_id = table_loop.open_cursor; + program.emit_insn(Insn::NextAsync { cursor_id }); + program.emit_insn(Insn::NextAwait { + cursor_id, + pc_if_next: table_loop.rewind_offset, + }); + program.fixup_insn( + table_loop.rewind_offset, + Insn::RewindAwait { + cursor_id: table_loop.open_cursor, + pc_if_empty: program.offset(), + }, + ); + } +} + +fn translate_table_open_cursor( + program: &mut ProgramBuilder, + select_context: &mut SelectContext, + table: &Table, +) { + let cursor_id = program.alloc_cursor_id(); + let root_page = match table { + Table::BTree(btree) => btree.root_page, + Table::Pseudo(_) => todo!(), + }; + program.emit_insn(Insn::OpenReadAsync { + cursor_id, + root_page, + }); + program.emit_insn(Insn::OpenReadAwait); + select_context.loops.push(LoopInfo { + table: table.clone(), + open_cursor: cursor_id, + rewind_offset: 0, + }); +} + +fn translate_table_open_loop( + program: &mut ProgramBuilder, + select_context: &mut SelectContext, + loop_index: usize, +) { + let table_loop = select_context.loops.get_mut(loop_index).unwrap(); + program.emit_insn(Insn::RewindAsync { + cursor_id: table_loop.open_cursor, + }); + let rewind_await_offset = 
program.emit_placeholder(); + table_loop.rewind_offset = rewind_await_offset; +} + fn translate_columns( program: &mut ProgramBuilder, - cursor_id: Option, select: &Select, -) -> (usize, usize) { + context: &SelectContext, +) -> Result<(usize, usize)> { let register_start = program.next_free_register(); // allocate one register as output for each col @@ -245,53 +358,83 @@ fn translate_columns( let mut target = register_start; for (col, info) in select.columns.iter().zip(select.column_info.iter()) { - translate_column(program, cursor_id, select.from.as_ref(), col, info, target); + translate_column(program, select, context, col, info, target)?; target += info.columns_to_allocate; } - (register_start, register_end) + Ok((register_start, register_end)) } fn translate_column( program: &mut ProgramBuilder, - cursor_id: Option, - table: Option<&crate::schema::Table>, + select: &Select, + context: &SelectContext, col: &sqlite3_parser::ast::ResultColumn, info: &ColumnInfo, target_register: usize, // where to store the result, in case of star it will be the start of registers added -) { +) -> Result<()> { match col { sqlite3_parser::ast::ResultColumn::Expr(expr, _) => { if info.is_aggregation_function() { let _ = - translate_aggregation(program, cursor_id, table, expr, info, target_register); + translate_aggregation(program, select, context, expr, info, target_register)?; } else { - let _ = translate_expr(program, cursor_id, table, expr, target_register); + let _ = translate_expr(program, select, context, expr, target_register)?; } } sqlite3_parser::ast::ResultColumn::Star => { - let table = table.unwrap(); - for (i, col) in table.columns().iter().enumerate() { - if table.column_is_rowid_alias(col) { - program.emit_insn(Insn::RowId { - cursor_id: cursor_id.unwrap(), - dest: target_register + i, - }); - } else { - program.emit_insn(Insn::Column { - column: i, - dest: target_register + i, - cursor_id: cursor_id.unwrap(), - }); + let table = select.from.as_ref().unwrap(); + 
translate_table_star(table, program, context, target_register); + let root_table_columns = table.columns().len(); + + if select.joins.is_some() { + for join in select.joins.as_ref().unwrap() { + let table = &join.table; + translate_table_star( + table, + program, + context, + target_register + root_table_columns, + ); } } } sqlite3_parser::ast::ResultColumn::TableStar(_) => todo!(), } + Ok(()) +} + +fn translate_table_star( + table: &Table, + program: &mut ProgramBuilder, + context: &SelectContext, + target_register: usize, +) { + let table_cursor = context + .loops + .iter() + .find(|v| v.table == *table) + .unwrap() + .open_cursor; + for (i, col) in table.columns().iter().enumerate() { + if table.column_is_rowid_alias(col) { + program.emit_insn(Insn::RowId { + cursor_id: table_cursor, + dest: target_register + i, + }); + } else { + program.emit_insn(Insn::Column { + column: i, + dest: target_register + i, + cursor_id: table_cursor, + }); + } + } } fn analyze_columns( columns: &Vec, table: Option<&crate::schema::Table>, + joins: &Option>, ) -> Vec { let mut column_information_list = Vec::with_capacity(columns.len()); for column in columns { @@ -299,6 +442,11 @@ fn analyze_columns( info.columns_to_allocate = 1; if let sqlite3_parser::ast::ResultColumn::Star = column { info.columns_to_allocate = table.unwrap().columns().len(); + if joins.is_some() { + for join in joins.as_ref().unwrap() { + info.columns_to_allocate += join.table.columns().len(); + } + } } else { analyze_column(column, &mut info); } @@ -309,51 +457,58 @@ fn analyze_columns( /// Analyze a column expression. /// -/// The function walks a column expression trying to find aggregation functions. -/// If it finds one it will save information about it. +/// This function will walk all columns and find information about: +/// * Aggregation functions. 
fn analyze_column(column: &sqlite3_parser::ast::ResultColumn, column_info_out: &mut ColumnInfo) { match column { - sqlite3_parser::ast::ResultColumn::Expr(expr, _) => match expr { - ast::Expr::FunctionCall { - name, - distinctness: _, - args, - filter_over: _, - } => { - let func_type = match normalize_ident(name.0.as_str()).as_str() { - "avg" => Some(AggFunc::Avg), - "count" => Some(AggFunc::Count), - "group_concat" => Some(AggFunc::GroupConcat), - "max" => Some(AggFunc::Max), - "min" => Some(AggFunc::Min), - "string_agg" => Some(AggFunc::StringAgg), - "sum" => Some(AggFunc::Sum), - "total" => Some(AggFunc::Total), - _ => None, - }; - if func_type.is_none() { - analyze_column(column, column_info_out); - } else { - column_info_out.func = func_type; - // TODO(pere): use lifetimes for args? Arenas would be lovely here :( - column_info_out.args.clone_from(args); - } - } - ast::Expr::FunctionCallStar { .. } => todo!(), - _ => {} - }, + sqlite3_parser::ast::ResultColumn::Expr(expr, _) => analyze_expr(expr, column_info_out), ast::ResultColumn::Star => {} ast::ResultColumn::TableStar(_) => {} } } +fn analyze_expr(expr: &Expr, column_info_out: &mut ColumnInfo) { + match expr { + ast::Expr::FunctionCall { + name, + distinctness: _, + args, + filter_over: _, + } => { + let func_type = match normalize_ident(name.0.as_str()).as_str() { + "avg" => Some(AggFunc::Avg), + "count" => Some(AggFunc::Count), + "group_concat" => Some(AggFunc::GroupConcat), + "max" => Some(AggFunc::Max), + "min" => Some(AggFunc::Min), + "string_agg" => Some(AggFunc::StringAgg), + "sum" => Some(AggFunc::Sum), + "total" => Some(AggFunc::Total), + _ => None, + }; + if func_type.is_none() { + let args = args.as_ref().unwrap(); + if args.len() > 0 { + analyze_expr(&args.get(0).unwrap(), column_info_out); + } + } else { + column_info_out.func = func_type; + // TODO(pere): use lifetimes for args? Arenas would be lovely here :( + column_info_out.args.clone_from(args); + } + } + ast::Expr::FunctionCallStar { .. 
} => todo!(), + _ => {} + } +} + fn translate_expr( program: &mut ProgramBuilder, - cursor_id: Option, - table: Option<&crate::schema::Table>, + select: &Select, + context: &SelectContext, expr: &ast::Expr, target_register: usize, -) -> usize { +) -> Result { match expr { ast::Expr::Between { .. } => todo!(), ast::Expr::Binary(_, _, _) => todo!(), @@ -365,20 +520,21 @@ fn translate_expr( ast::Expr::FunctionCall { .. } => todo!(), ast::Expr::FunctionCallStar { .. } => todo!(), ast::Expr::Id(ident) => { - let (idx, col) = table.unwrap().get_column(&ident.0).unwrap(); + // let (idx, col) = table.unwrap().get_column(&ident.0).unwrap(); + let (idx, col, cursor_id) = resolve_ident_table(&ident.0, select, context)?; if col.primary_key { program.emit_insn(Insn::RowId { - cursor_id: cursor_id.unwrap(), + cursor_id, dest: target_register, }); } else { program.emit_insn(Insn::Column { column: idx, dest: target_register, - cursor_id: cursor_id.unwrap(), + cursor_id, }); } - target_register + Ok(target_register) } ast::Expr::InList { .. } => todo!(), ast::Expr::InSelect { .. 
} => todo!(), @@ -400,14 +556,14 @@ fn translate_expr( dest: target_register, }); } - target_register + Ok(target_register) } ast::Literal::String(s) => { program.emit_insn(Insn::String8 { value: s[1..s.len() - 1].to_string(), dest: target_register, }); - target_register + Ok(target_register) } ast::Literal::Blob(_) => todo!(), ast::Literal::Keyword(_) => todo!(), @@ -427,10 +583,56 @@ fn translate_expr( } } +fn resolve_ident_table<'a>( + ident: &String, + select: &'a Select, + context: &SelectContext, +) -> Result<(usize, &'a Column, usize)> { + let table = select.from.as_ref().unwrap(); + + let res = table + .columns() + .iter() + .enumerate() + .find(|(_, col)| col.name == *ident); + if res.is_some() { + let (idx, col) = res.unwrap(); + let cursor_id = context + .loops + .iter() + .find(|l| l.table == *table) + .unwrap() + .open_cursor; + return Ok((idx, col, cursor_id)); + } + + if select.joins.is_some() { + for join in select.joins.as_ref().unwrap().iter() { + let res = join + .table + .columns() + .iter() + .enumerate() + .find(|(_, col)| col.name == *ident); + if res.is_some() { + let (idx, col) = res.unwrap(); + let cursor_id = context + .loops + .iter() + .find(|l| l.table == join.table) // cursor of the joined table that owns the column, not the FROM table + .unwrap() + .open_cursor; + return Ok((idx, col, cursor_id)); + } + } + } + anyhow::bail!("Parse error: column with name {} not found", ident.as_str()); +} + fn translate_aggregation( program: &mut ProgramBuilder, - cursor_id: Option, - table: Option<&crate::schema::Table>, + select: &Select, + context: &SelectContext, expr: &ast::Expr, info: &ColumnInfo, target_register: usize, @@ -447,7 +649,7 @@ fn translate_aggregation( } let expr = &args[0]; let expr_reg = program.alloc_register(); - let _ = translate_expr(program, cursor_id, table, expr, expr_reg); + let _ = translate_expr(program, select, context, expr, expr_reg)?; program.emit_insn(Insn::AggStep { acc_reg: target_register, col: expr_reg, @@ -507,7 +709,7 @@ fn translate_aggregation( } let expr = &args[0]; let
expr_reg = program.alloc_register(); - let _ = translate_expr(program, cursor_id, table, expr, expr_reg); + let _ = translate_expr(program, select, context, expr, expr_reg)?; program.emit_insn(Insn::AggStep { acc_reg: target_register, col: expr_reg,