From 865b3a04e9ff6c2d2010d20c028c9d809da3a554 Mon Sep 17 00:00:00 2001 From: Bennett Clement Date: Mon, 22 Jul 2024 00:27:46 +0800 Subject: [PATCH] Implement orderby translation --- core/expr.rs | 33 +++++--- core/function.rs | 1 + core/schema.rs | 3 +- core/select.rs | 2 + core/translate.rs | 179 +++++++++++++++++++++++++++++++++++-------- core/vdbe.rs | 107 +++++++++++++++++++------- core/where_clause.rs | 91 ++++++++++++++++------ 7 files changed, 325 insertions(+), 91 deletions(-) diff --git a/core/expr.rs b/core/expr.rs index fbdfff86d..c5504b4be 100644 --- a/core/expr.rs +++ b/core/expr.rs @@ -78,6 +78,7 @@ pub fn build_select<'a>(schema: &Schema, select: &'a ast::Select) -> Result(schema: &Schema, select: &'a ast::Select) -> Result, ) -> Result { match expr { ast::Expr::Between { .. } => todo!(), ast::Expr::Binary(e1, op, e2) => { let e1_reg = program.alloc_register(); let e2_reg = program.alloc_register(); - let _ = translate_expr(program, select, e1, e1_reg)?; - let _ = translate_expr(program, select, e2, e2_reg)?; + let _ = translate_expr(program, select, e1, e1_reg, cursor_hint)?; + let _ = translate_expr(program, select, e2, e2_reg, cursor_hint)?; match op { ast::Operator::NotEquals => { @@ -250,7 +253,13 @@ pub fn translate_expr( // whenever a not null check succeeds, we jump to the end of the series let label_coalesce_end = program.allocate_label(); for (index, arg) in args.iter().enumerate() { - let reg = translate_expr(program, select, arg, target_register)?; + let reg = translate_expr( + program, + select, + arg, + target_register, + cursor_hint, + )?; if index < args.len() - 1 { program.emit_insn_with_label_dependency( Insn::NotNull { @@ -282,7 +291,7 @@ pub fn translate_expr( }; for arg in args { let reg = program.alloc_register(); - let _ = translate_expr(program, select, &arg, reg)?; + let _ = translate_expr(program, select, &arg, reg, cursor_hint)?; match arg { ast::Expr::Literal(_) => program.mark_last_insn_constant(), _ => {} @@ -315,7 +324,7 @@ pub fn translate_expr( }; let regs = program.alloc_register(); - translate_expr(program, select, &args[0], regs)?; + translate_expr(program, select, &args[0], regs, cursor_hint)?; program.emit_insn(Insn::Function { start_reg: regs, dest: target_register, @@ -356,7 +365,7 @@ pub fn translate_expr( for arg in args.iter() { let reg = program.alloc_register(); - translate_expr(program, select, arg, reg)?; + translate_expr(program, select, arg, reg, cursor_hint)?; if let ast::Expr::Literal(_) = arg { program.mark_last_insn_constant(); } @@ -378,7 +387,8 @@ pub fn translate_expr( ast::Expr::FunctionCallStar { .. } => todo!(), ast::Expr::Id(ident) => { // let (idx, col) = table.unwrap().get_column(&ident.0).unwrap(); - let (idx, col, cursor_id) = resolve_ident_table(program, &ident.0, select)?; + let (idx, col, cursor_id) = + resolve_ident_table(program, &ident.0, select, cursor_hint)?; if col.primary_key { program.emit_insn(Insn::RowId { cursor_id, @@ -439,7 +449,8 @@ pub fn translate_expr( ast::Expr::NotNull(_) => todo!(), ast::Expr::Parenthesized(_) => todo!(), ast::Expr::Qualified(tbl, ident) => { - let (idx, col, cursor_id) = resolve_ident_qualified(program, &tbl.0, &ident.0, select)?; + let (idx, col, cursor_id) = + resolve_ident_qualified(program, &tbl.0, &ident.0, select, cursor_hint)?; if col.primary_key { program.emit_insn(Insn::RowId { cursor_id, @@ -563,6 +574,7 @@ pub fn resolve_ident_qualified<'a>( table_name: &String, ident: &String, select: &'a Select, + cursor_hint: Option, ) -> Result<(usize, &'a Column, usize)> { for join in &select.src_tables { match join.table { @@ -579,7 +591,7 @@ pub fn resolve_ident_qualified<'a>( .find(|(_, col)| col.name == *ident); if res.is_some() { let (idx, col) = res.unwrap(); - let cursor_id = program.resolve_cursor_id(&table_identifier); + let cursor_id = program.resolve_cursor_id(&table_identifier, cursor_hint); return Ok((idx, col, cursor_id)); } } @@ -598,6 +610,7 @@ pub fn resolve_ident_table<'a>( program: &ProgramBuilder, ident: &String, select: &'a Select, + cursor_hint: Option, ) -> Result<(usize, &'a Column, usize)> { let mut found = Vec::new(); for join in &select.src_tables { @@ -614,7 +627,7 @@ pub fn resolve_ident_table<'a>( .find(|(_, col)| col.name == *ident); if res.is_some() { let (idx, col) = res.unwrap(); - let cursor_id = program.resolve_cursor_id(&table_identifier); + let cursor_id = program.resolve_cursor_id(&table_identifier, cursor_hint); found.push((idx, col, cursor_id)); } } diff --git a/core/function.rs b/core/function.rs index 224e96500..0ec15ce28 100644 --- a/core/function.rs +++ b/core/function.rs @@ -56,6 +56,7 @@ impl ToString for SingleRowFunc { } } +#[derive(Debug)] pub enum Func { Agg(AggFunc), SingleRow(SingleRowFunc), diff --git a/core/schema.rs b/core/schema.rs index 7f41dffc0..5544c20bf 100644 --- a/core/schema.rs +++ b/core/schema.rs @@ -1,3 +1,4 @@ +use crate::types::OwnedRecord; use crate::util::normalize_ident; use anyhow::Result; use core::fmt; @@ -54,7 +55,7 @@ impl Table { pub fn get_name(&self) -> &str { match self { Table::BTree(table) => &table.name, - Table::Pseudo(table) => &table.columns[0].name, + Table::Pseudo(table) => "pseudo", } } diff --git a/core/select.rs b/core/select.rs index ba23b1b4a..f54e581f8 100644 --- a/core/select.rs +++ b/core/select.rs @@ -8,6 +8,7 @@ pub struct SrcTable<'a> { pub join_info: Option<&'a ast::JoinedSelectTable>, // FIXME: preferably this should be a reference with lifetime == Select ast expr } +#[derive(Debug)] pub struct ColumnInfo<'a> { pub func: Option, pub args: &'a Option>, @@ -39,6 +40,7 @@ pub struct Select<'a> { pub column_info: Vec>, pub src_tables: Vec>, // Tables we use to get data from. This includes "from" and "joins" pub limit: &'a Option, + pub order_by: &'a Option>, pub exist_aggregation: bool, pub where_clause: &'a Option, /// Ordered list of opened read table loops diff --git a/core/translate.rs b/core/translate.rs index 14f1ff207..75a4a6290 100644 --- a/core/translate.rs +++ b/core/translate.rs @@ -20,6 +20,13 @@ struct LimitInfo { goto_label: BranchOffset, } +#[derive(Debug)] +struct SortInfo { + sorter_cursor: usize, + sorter_reg: usize, + count: usize, +} + /// Translate SQL statement into bytecode program. pub fn translate( schema: &Schema, @@ -49,10 +56,24 @@ fn translate_select(mut select: Select) -> Result { ); let start_offset = program.offset(); + let mut sort_info = if let Some(_) = select.order_by { + let sorter_cursor = program.alloc_cursor_id(None, None); + program.emit_insn(Insn::SorterOpen { + cursor_id: sorter_cursor, + }); + Some(SortInfo { + sorter_cursor, + sorter_reg: 0, // will be overwritten later + count: 0, // will be overwritten later + }) + } else { + None + }; + let limit_info = if let Some(limit) = &select.limit { assert!(limit.offset.is_none()); let target_register = program.alloc_register(); - let limit_reg = translate_expr(&mut program, &select, &limit.expr, target_register)?; + let limit_reg = translate_expr(&mut program, &select, &limit.expr, target_register, None)?; let num = if let ast::Expr::Literal(ast::Literal::Numeric(num)) = &limit.expr { num.parse::()? } else { @@ -79,14 +100,41 @@ fn translate_select(mut select: Select) -> Result { if !select.src_tables.is_empty() { let constraint = translate_tables_begin(&mut program, &mut select)?; - let (register_start, register_end) = translate_columns(&mut program, &select)?; + let (register_start, column_count) = if let Some(sort_columns) = select.order_by { + let start = program.next_free_register(); + for col in sort_columns.iter() { + let target = program.alloc_register(); + translate_expr(&mut program, &select, &col.expr, target, None)?; + } + let (_, result_cols_count) = translate_columns(&mut program, &select, None)?; + sort_info + .as_mut() + .map(|inner| inner.count = result_cols_count + sort_columns.len() + 1); // +1 for the key + (start, result_cols_count + sort_columns.len()) + } else { + translate_columns(&mut program, &select, None)? + }; if !select.exist_aggregation { - program.emit_insn(Insn::ResultRow { - start_reg: register_start, - count: register_end - register_start, - }); - emit_limit_insn(&limit_info, &mut program); + if let Some(ref mut sort_info) = sort_info { + let dest = program.alloc_register(); + program.emit_insn(Insn::MakeRecord { + start_reg: register_start, + count: column_count, + dest_reg: dest, + }); + program.emit_insn(Insn::SorterInsert { + cursor_id: sort_info.sorter_cursor, + record_reg: dest, + }); + sort_info.sorter_reg = register_start; + } else { + program.emit_insn(Insn::ResultRow { + start_reg: register_start, + count: column_count, + }); + emit_limit_insn(&limit_info, &mut program); + } } translate_tables_end(&mut program, &select, constraint); @@ -105,23 +153,30 @@ fn translate_select(mut select: Select) -> Result { // only one result row program.emit_insn(Insn::ResultRow { start_reg: register_start, - count: register_end - register_start, + count: column_count, }); emit_limit_insn(&limit_info, &mut program); } } else { assert!(!select.exist_aggregation); + assert!(sort_info.is_none()); let where_maybe = translate_where(&select, &mut program)?; - let (register_start, register_end) = translate_columns(&mut program, &select)?; + let (register_start, count) = translate_columns(&mut program, &select, None)?; if let Some(where_clause_label) = where_maybe { program.resolve_label(where_clause_label, program.offset() + 1); } program.emit_insn(Insn::ResultRow { start_reg: register_start, - count: register_end - register_start, + count: count, }); emit_limit_insn(&limit_info, &mut program); }; + + // now do the sort for ORDER BY + if select.order_by.is_some() { + let _ = translate_sorter(&select, &mut program, &sort_info.unwrap()); + } + program.emit_insn(Insn::Halt); let halt_offset = program.offset() - 1; if let Some(limit_info) = limit_info { @@ -155,6 +210,46 @@ fn emit_limit_insn(limit_info: &Option, program: &mut ProgramBuilder) } } +fn translate_sorter( + select: &Select, + program: &mut ProgramBuilder, + sort_info: &SortInfo, +) -> Result<()> { + assert!(sort_info.count > 0); + + let pseudo_cursor = program.alloc_cursor_id(None, None); + let pseudo_content_reg = program.alloc_register(); + program.emit_insn(Insn::OpenPseudo { + cursor_id: pseudo_cursor, + content_reg: pseudo_content_reg, + num_fields: sort_info.count, + }); + let label = program.allocate_label(); + program.emit_insn_with_label_dependency( + Insn::SorterSort { + cursor_id: sort_info.sorter_cursor, + pc_if_empty: label, + }, + label, + ); + let sorter_data_offset = program.offset(); + program.emit_insn(Insn::SorterData { + cursor_id: sort_info.sorter_cursor, + dest_reg: pseudo_content_reg, + }); + let (register_start, count) = translate_columns(program, select, Some(pseudo_cursor))?; + program.emit_insn(Insn::ResultRow { + start_reg: register_start, + count, + }); + program.emit_insn(Insn::SorterNext { + cursor_id: sort_info.sorter_cursor, + pc_if_next: sorter_data_offset, + }); + program.resolve_label(label, program.offset()); + Ok(()) +} + fn translate_tables_begin( program: &mut ProgramBuilder, select: &mut Select, @@ -164,7 +259,7 @@ fn translate_tables_begin( select.loops.push(loop_info); } - let conditions = evaluate_conditions(program, select)?; + let conditions = evaluate_conditions(program, select, None)?; for loop_info in &mut select.loops { let mut left_join_match_flag_maybe = None; @@ -181,7 +276,7 @@ fn translate_tables_begin( translate_table_open_loop(program, loop_info, left_join_match_flag_maybe); } - translate_conditions(program, select, conditions) + translate_conditions(program, select, conditions, None) } fn handle_skip_row( @@ -293,7 +388,7 @@ fn translate_table_open_cursor(program: &mut ProgramBuilder, table: &SrcTable) - Some(alias) => alias.clone(), None => table.table.get_name().to_string(), }; - let cursor_id = program.alloc_cursor_id(table_identifier, table.table.clone()); + let cursor_id = program.alloc_cursor_id(Some(table_identifier), Some(table.table.clone())); let root_page = match &table.table { Table::BTree(btree) => btree.root_page, Table::Pseudo(_) => todo!(), @@ -337,7 +432,11 @@ fn translate_table_open_loop( loop_info.rewind_offset = program.offset() - 1; } -fn translate_columns(program: &mut ProgramBuilder, select: &Select) -> Result<(usize, usize)> { +fn translate_columns( + program: &mut ProgramBuilder, + select: &Select, + cursor_hint: Option, +) -> Result<(usize, usize)> { let register_start = program.next_free_register(); // allocate one register as output for each col @@ -347,14 +446,14 @@ fn translate_columns(program: &mut ProgramBuilder, select: &Select) -> Result<(u .map(|col| col.columns_to_allocate) .sum(); program.alloc_registers(registers); - let register_end = program.next_free_register(); + let count = program.next_free_register() - register_start; let mut target = register_start; for (col, info) in select.columns.iter().zip(select.column_info.iter()) { - translate_column(program, select, col, info, target)?; + translate_column(program, select, col, info, target, cursor_hint)?; target += info.columns_to_allocate; } - Ok((register_start, register_end)) + Ok((register_start, count)) } fn translate_column( @@ -363,19 +462,27 @@ fn translate_column( col: &ast::ResultColumn, info: &ColumnInfo, target_register: usize, // where to store the result, in case of star it will be the start of registers added + cursor_hint: Option, ) -> Result<()> { match col { ast::ResultColumn::Expr(expr, _) => { if info.is_aggregation_function() { - let _ = translate_aggregation(program, select, expr, info, target_register)?; + let _ = translate_aggregation( + program, + select, + expr, + info, + target_register, + cursor_hint, + )?; } else { - let _ = translate_expr(program, select, expr, target_register)?; + let _ = translate_expr(program, select, expr, target_register, cursor_hint)?; } } ast::ResultColumn::Star => { let mut target_register = target_register; for join in &select.src_tables { - translate_table_star(join, program, target_register); + translate_table_star(join, program, target_register, cursor_hint); target_register += &join.table.columns().len(); } } @@ -384,12 +491,17 @@ fn translate_column( Ok(()) } -fn translate_table_star(table: &SrcTable, program: &mut ProgramBuilder, target_register: usize) { +fn translate_table_star( + table: &SrcTable, + program: &mut ProgramBuilder, + target_register: usize, + cursor_hint: Option, +) { let table_identifier = match table.alias { Some(alias) => alias.clone(), None => table.table.get_name().to_string(), }; - let table_cursor = program.resolve_cursor_id(&table_identifier); + let table_cursor = program.resolve_cursor_id(&table_identifier, cursor_hint); let table = &table.table; for (i, col) in table.columns().iter().enumerate() { let col_target_register = target_register + i; @@ -415,6 +527,7 @@ fn translate_aggregation( expr: &ast::Expr, info: &ColumnInfo, target_register: usize, + cursor_hint: Option, ) -> Result { let _ = expr; assert!(info.func.is_some()); @@ -430,7 +543,7 @@ fn translate_aggregation( } let expr = &args[0]; let expr_reg = program.alloc_register(); - let _ = translate_expr(program, select, expr, expr_reg)?; + let _ = translate_expr(program, select, expr, expr_reg, cursor_hint)?; program.emit_insn(Insn::AggStep { acc_reg: target_register, col: expr_reg, @@ -445,7 +558,7 @@ fn translate_aggregation( } else { let expr = &args[0]; let expr_reg = program.alloc_register(); - let _ = translate_expr(program, select, expr, expr_reg); + let _ = translate_expr(program, select, expr, expr_reg, cursor_hint); expr_reg }; program.emit_insn(Insn::AggStep { @@ -486,10 +599,11 @@ fn translate_aggregation( delimiter_expr = ast::Expr::Literal(Literal::String(String::from("\",\""))); } - if let Err(error) = translate_expr(program, select, expr, expr_reg) { + if let Err(error) = translate_expr(program, select, expr, expr_reg, cursor_hint) { anyhow::bail!(error); } - if let Err(error) = translate_expr(program, select, &delimiter_expr, delimiter_reg) + if let Err(error) = + translate_expr(program, select, &delimiter_expr, delimiter_reg, cursor_hint) { anyhow::bail!(error); } @@ -509,7 +623,7 @@ fn translate_aggregation( } let expr = &args[0]; let expr_reg = program.alloc_register(); - let _ = translate_expr(program, select, expr, expr_reg); + let _ = translate_expr(program, select, expr, expr_reg, cursor_hint); program.emit_insn(Insn::AggStep { acc_reg: target_register, col: expr_reg, @@ -524,7 +638,7 @@ fn translate_aggregation( } let expr = &args[0]; let expr_reg = program.alloc_register(); - let _ = translate_expr(program, select, expr, expr_reg); + let _ = translate_expr(program, select, expr, expr_reg, cursor_hint); program.emit_insn(Insn::AggStep { acc_reg: target_register, col: expr_reg, @@ -558,10 +672,11 @@ fn translate_aggregation( _ => anyhow::bail!("Incorrect delimiter parameter"), }; - if let Err(error) = translate_expr(program, select, expr, expr_reg) { + if let Err(error) = translate_expr(program, select, expr, expr_reg, cursor_hint) { anyhow::bail!(error); } - if let Err(error) = translate_expr(program, select, &delimiter_expr, delimiter_reg) + if let Err(error) = + translate_expr(program, select, &delimiter_expr, delimiter_reg, cursor_hint) { anyhow::bail!(error); } @@ -581,7 +696,7 @@ fn translate_aggregation( } let expr = &args[0]; let expr_reg = program.alloc_register(); - let _ = translate_expr(program, select, expr, expr_reg)?; + let _ = translate_expr(program, select, expr, expr_reg, cursor_hint)?; program.emit_insn(Insn::AggStep { acc_reg: target_register, col: expr_reg, @@ -596,7 +711,7 @@ fn translate_aggregation( } let expr = &args[0]; let expr_reg = program.alloc_register(); - let _ = translate_expr(program, select, expr, expr_reg)?; + let _ = translate_expr(program, select, expr, expr_reg, cursor_hint)?; program.emit_insn(Insn::AggStep { acc_reg: target_register, col: expr_reg, diff --git a/core/vdbe.rs b/core/vdbe.rs index 40d56aea5..8eb8ab7d2 100644 --- a/core/vdbe.rs +++ b/core/vdbe.rs @@ -228,6 +228,7 @@ pub enum Insn { // Sort the rows in the sorter. SorterSort { cursor_id: CursorID, + pc_if_empty: BranchOffset, }, // Retrieve the next row from the sorter. @@ -266,7 +267,7 @@ pub struct ProgramBuilder { unresolved_labels: Vec>, next_insn_label: Option, // Cursors that are referenced by the program. Indexed by CursorID. - cursor_ref: Vec<(String, Table)>, + cursor_ref: Vec<(Option, Option)>, // List of deferred label resolutions. Each entry is a pair of (label, insn_reference). deferred_label_resolutions: Vec<(BranchOffset, InsnReference)>, } @@ -298,11 +299,23 @@ impl ProgramBuilder { reg } + pub fn drop_register(&mut self) { + self.next_free_register -= 1; + } + + pub fn drop_registers(&mut self, amount: usize) { + self.next_free_register -= amount; + } + pub fn next_free_register(&self) -> usize { self.next_free_register } - pub fn alloc_cursor_id(&mut self, table_identifier: String, table: Table) -> usize { + pub fn alloc_cursor_id( + &mut self, + table_identifier: Option, + table: Option
, + ) -> usize { let cursor = self.next_free_cursor_id; self.next_free_cursor_id += 1; self.cursor_ref.push((table_identifier, table)); @@ -318,6 +331,17 @@ impl ProgramBuilder { } } + pub fn last_insn(&self) -> Option<&Insn> { + self.insns.last() + } + + pub fn last_of_type(&self, typ: std::mem::Discriminant) -> Option<&Insn> { + self.insns + .iter() + .rev() + .find(|v| std::mem::discriminant(*v) == typ) + } + // Emit an instruction that will be put at the end of the program (after Transaction statement). // This is useful for instructions that otherwise will be unnecessarily repeated in a loop. // Example: In `SELECT * from users where name='John'`, it is unnecessary to set r[1]='John' as we SCAN users table. @@ -477,6 +501,10 @@ impl ProgramBuilder { assert!(*pc_if_next < 0); *pc_if_next = to_offset; } + Insn::SorterSort { pc_if_empty, .. } => { + assert!(*pc_if_empty < 0); + *pc_if_empty = to_offset; + } Insn::NotNull { reg: _reg, target_pc, @@ -497,10 +525,21 @@ impl ProgramBuilder { } // translate table to cursor id - pub fn resolve_cursor_id(&self, table_identifier: &str) -> CursorID { + pub fn resolve_cursor_id( + &self, + table_identifier: &str, + cursor_hint: Option, + ) -> CursorID { + if let Some(cursor_hint) = cursor_hint { + return cursor_hint; + } self.cursor_ref .iter() - .position(|(t_ident, _)| *t_ident == table_identifier) + .position(|(t_ident, _)| { + t_ident + .as_ref() + .is_some_and(|ident| ident == table_identifier) + }) .unwrap() } @@ -566,7 +605,7 @@ impl ProgramState { pub struct Program { pub max_registers: usize, pub insns: Vec, - pub cursor_ref: Vec<(String, Table)>, + pub cursor_ref: Vec<(Option, Option
)>, } impl Program { @@ -1211,10 +1250,16 @@ impl Program { cursor.insert(record)?; state.pc += 1; } - Insn::SorterSort { cursor_id } => { - let cursor = cursors.get_mut(cursor_id).unwrap(); - cursor.rewind()?; - state.pc += 1; + Insn::SorterSort { + cursor_id, + pc_if_empty, + } => { + if let Some(cursor) = cursors.get_mut(cursor_id) { + cursor.rewind()?; + state.pc += 1; + } else { + state.pc = *pc_if_empty; + } } Insn::SorterNext { cursor_id, @@ -1589,8 +1634,14 @@ fn insn_to_str(program: &Program, addr: InsnReference, insn: &Insn, indent: Stri format!( "r[{}]={}.{}", dest, - table.get_name(), - table.column_index_to_name(*column).unwrap() + table + .as_ref() + .and_then(|x| Some(x.get_name())) + .unwrap_or(format!("cursor {}", cursor_id).as_str()), + table + .as_ref() + .and_then(|x| x.column_index_to_name(*column)) + .unwrap_or(format!("column {}", *column).as_str()) ), ) } @@ -1605,7 +1656,12 @@ fn insn_to_str(program: &Program, addr: InsnReference, insn: &Insn, indent: Stri *dest_reg as i32, OwnedValue::Text(Rc::new("".to_string())), 0, - format!("r[{}..{}] -> r[{}]", start_reg, start_reg + count, dest_reg), + format!( + "r[{}]=mkrec(r[{}..{}])", + dest_reg, + start_reg, + start_reg + count - 1, + ), ), Insn::ResultRow { start_reg, count } => ( "ResultRow", @@ -1714,7 +1770,11 @@ fn insn_to_str(program: &Program, addr: InsnReference, insn: &Insn, indent: Stri format!( "r[{}]={}.rowid", dest, - &program.cursor_ref[*cursor_id].1.get_name() + &program.cursor_ref[*cursor_id] + .1 + .as_ref() + .and_then(|x| Some(x.get_name())) + .unwrap_or(format!("cursor {}", cursor_id).as_str()) ), ), Insn::DecrJumpZero { reg, target_pc } => ( @@ -1778,14 +1838,17 @@ fn insn_to_str(program: &Program, addr: InsnReference, insn: &Insn, indent: Stri *cursor_id as i32, *record_reg as i32, 0, - OwnedValue::Text(Rc::new("".to_string())), + OwnedValue::Integer(0), 0, format!("key=r[{}]", record_reg), ), - Insn::SorterSort { cursor_id } => ( + Insn::SorterSort { + cursor_id, + pc_if_empty, + } => ( "SorterSort", *cursor_id as i32, - 0, + *pc_if_empty as i32, 0, OwnedValue::Text(Rc::new("".to_string())), 0, @@ -1833,11 +1896,7 @@ fn insn_to_str(program: &Program, addr: InsnReference, insn: &Insn, indent: Stri fn get_indent_count(indent_count: usize, curr_insn: &Insn, prev_insn: Option<&Insn>) -> usize { let indent_count = if let Some(insn) = prev_insn { match insn { - Insn::RewindAwait { - cursor_id: _, - pc_if_empty: _, - } => indent_count + 1, - Insn::SorterSort { cursor_id: _ } => indent_count + 1, + Insn::RewindAwait { .. } | Insn::SorterSort { .. } => indent_count + 1, _ => indent_count, } } else { @@ -1845,11 +1904,7 @@ fn get_indent_count(indent_count: usize, curr_insn: &Insn, prev_insn: Option<&In }; match curr_insn { - Insn::NextAsync { cursor_id: _ } => indent_count - 1, - Insn::SorterNext { - cursor_id: _, - pc_if_next: _, - } => indent_count - 1, + Insn::NextAsync { .. } | Insn::SorterNext { .. } => indent_count - 1, _ => indent_count, } } diff --git a/core/where_clause.rs b/core/where_clause.rs index 225d83246..fd801523e 100644 --- a/core/where_clause.rs +++ b/core/where_clause.rs @@ -53,7 +53,7 @@ pub fn translate_where( ) -> Result> { if let Some(w) = &select.where_clause { let label = program.allocate_label(); - translate_condition_expr(program, select, w, label, false)?; + translate_condition_expr(program, select, w, label, false, None)?; Ok(Some(label)) } else { Ok(None) @@ -63,6 +63,7 @@ pub fn translate_where( pub fn evaluate_conditions( program: &mut ProgramBuilder, select: &Select, + cursor_hint: Option, ) -> Result> { let join_constraints = select .src_tables @@ -80,7 +81,12 @@ pub fn evaluate_conditions( let parsed_where_maybe = select.where_clause.as_ref().map(|where_clause| Where { constraint_expr: where_clause.clone(), no_match_jump_label: program.allocate_label(), - no_match_target_cursor: get_no_match_target_cursor(program, select, where_clause), + no_match_target_cursor: get_no_match_target_cursor( + program, + select, + where_clause, + cursor_hint, + ), }); let parsed_join_maybe = join_maybe.and_then(|(constraint, _)| { @@ -88,7 +94,12 @@ pub fn evaluate_conditions( Some(Join { constraint_expr: expr.clone(), no_match_jump_label: program.allocate_label(), - no_match_target_cursor: get_no_match_target_cursor(program, select, expr), + no_match_target_cursor: get_no_match_target_cursor( + program, + select, + expr, + cursor_hint, + ), }) } else { None @@ -155,6 +166,7 @@ pub fn translate_conditions( program: &mut ProgramBuilder, select: &Select, conditions: Option, + cursor_hint: Option, ) -> Result> { match conditions.as_ref() { Some(QueryConstraint::Left(Left { @@ -171,6 +183,7 @@ pub fn translate_conditions( &where_clause.constraint_expr, where_clause.no_match_jump_label, false, + cursor_hint, )?; } if let Some(join_clause) = join_clause { @@ -180,6 +193,7 @@ pub fn translate_conditions( &join_clause.constraint_expr, join_clause.no_match_jump_label, false, + cursor_hint, )?; } // Set match flag to 1 if we hit the marker (i.e. jump didn't happen to no_match_label as a result of the condition) @@ -197,6 +211,7 @@ pub fn translate_conditions( &where_clause.constraint_expr, where_clause.no_match_jump_label, false, + cursor_hint, )?; } if let Some(join_clause) = &inner_join.join_clause { @@ -206,6 +221,7 @@ pub fn translate_conditions( &join_clause.constraint_expr, join_clause.no_match_jump_label, false, + cursor_hint, )?; } } @@ -221,40 +237,47 @@ fn translate_condition_expr( expr: &ast::Expr, target_jump: BranchOffset, jump_if_true: bool, // if true jump to target on op == true, if false invert op + cursor_hint: Option, ) -> Result<()> { match expr { ast::Expr::Between { .. } => todo!(), ast::Expr::Binary(lhs, ast::Operator::And, rhs) => { if jump_if_true { let label = program.allocate_label(); - let _ = translate_condition_expr(program, select, lhs, label, false); - let _ = translate_condition_expr(program, select, rhs, target_jump, true); + let _ = translate_condition_expr(program, select, lhs, label, false, cursor_hint); + let _ = + translate_condition_expr(program, select, rhs, target_jump, true, cursor_hint); program.resolve_label(label, program.offset()); } else { - let _ = translate_condition_expr(program, select, lhs, target_jump, false); - let _ = translate_condition_expr(program, select, rhs, target_jump, false); + let _ = + translate_condition_expr(program, select, lhs, target_jump, false, cursor_hint); + let _ = + translate_condition_expr(program, select, rhs, target_jump, false, cursor_hint); } } ast::Expr::Binary(lhs, ast::Operator::Or, rhs) => { if jump_if_true { - let _ = translate_condition_expr(program, select, lhs, target_jump, true); - let _ = translate_condition_expr(program, select, rhs, target_jump, true); + let _ = + translate_condition_expr(program, select, lhs, target_jump, true, cursor_hint); + let _ = + translate_condition_expr(program, select, rhs, target_jump, true, cursor_hint); } else { let label = program.allocate_label(); - let _ = translate_condition_expr(program, select, lhs, label, true); - let _ = translate_condition_expr(program, select, rhs, target_jump, false); + let _ = translate_condition_expr(program, select, lhs, label, true, cursor_hint); + let _ = + translate_condition_expr(program, select, rhs, target_jump, false, cursor_hint); program.resolve_label(label, program.offset()); } } ast::Expr::Binary(lhs, op, rhs) => { let lhs_reg = program.alloc_register(); let rhs_reg = program.alloc_register(); - let _ = translate_expr(program, select, lhs, lhs_reg); + let _ = translate_expr(program, select, lhs, lhs_reg, cursor_hint); match lhs.as_ref() { ast::Expr::Literal(_) => program.mark_last_insn_constant(), _ => {} } - let _ = translate_expr(program, select, rhs, rhs_reg); + let _ = translate_expr(program, select, rhs, rhs_reg, cursor_hint); match rhs.as_ref() { ast::Expr::Literal(_) => program.mark_last_insn_constant(), _ => {} @@ -434,9 +457,9 @@ fn translate_condition_expr( let pattern_reg = program.alloc_register(); let column_reg = program.alloc_register(); // LIKE(pattern, column). We should translate the pattern first before the column - let _ = translate_expr(program, select, rhs, pattern_reg)?; + let _ = translate_expr(program, select, rhs, pattern_reg, cursor_hint)?; program.mark_last_insn_constant(); - let _ = translate_expr(program, select, lhs, column_reg)?; + let _ = translate_expr(program, select, lhs, column_reg, cursor_hint)?; program.emit_insn(Insn::Function { func: SingleRowFunc::Like, start_reg: pattern_reg, @@ -476,19 +499,31 @@ fn introspect_expression_for_cursors( program: &ProgramBuilder, select: &Select, where_expr: &ast::Expr, + cursor_hint: Option, ) -> Result> { let mut cursors = vec![]; match where_expr { ast::Expr::Binary(e1, _, e2) => { - cursors.extend(introspect_expression_for_cursors(program, select, e1)?); - cursors.extend(introspect_expression_for_cursors(program, select, e2)?); + cursors.extend(introspect_expression_for_cursors( + program, + select, + e1, + cursor_hint, + )?); + cursors.extend(introspect_expression_for_cursors( + program, + select, + e2, + cursor_hint, + )?); } ast::Expr::Id(ident) => { - let (_, _, cursor_id) = resolve_ident_table(program, &ident.0, select)?; + let (_, _, cursor_id) = resolve_ident_table(program, &ident.0, select, cursor_hint)?; cursors.push(cursor_id); } ast::Expr::Qualified(tbl, ident) => { - let (_, _, cursor_id) = resolve_ident_qualified(program, &tbl.0, &ident.0, select)?; + let (_, _, cursor_id) = + resolve_ident_qualified(program, &tbl.0, &ident.0, select, cursor_hint)?; cursors.push(cursor_id); } ast::Expr::Literal(_) => {} @@ -499,8 +534,18 @@ fn introspect_expression_for_cursors( rhs, escape, } => { - cursors.extend(introspect_expression_for_cursors(program, select, lhs)?); - cursors.extend(introspect_expression_for_cursors(program, select, rhs)?); + cursors.extend(introspect_expression_for_cursors( + program, + select, + lhs, + cursor_hint, + )?); + cursors.extend(introspect_expression_for_cursors( + program, + select, + rhs, + cursor_hint, + )?); } other => { anyhow::bail!("Parse error: unsupported expression: {:?}", other); @@ -514,12 +559,14 @@ fn get_no_match_target_cursor( program: &ProgramBuilder, select: &Select, expr: &ast::Expr, + cursor_hint: Option, ) -> usize { // This is the hackiest part of the code. We are finding the cursor that should be advanced to the next row // when the condition is not met. This is done by introspecting the expression and finding the innermost cursor that is // used in the expression. This is a very naive approach and will not work in all cases. // Thankfully though it might be possible to just refine the logic contained here to make it work in all cases. Maybe. - let cursors = introspect_expression_for_cursors(program, select, expr).unwrap_or_default(); + let cursors = + introspect_expression_for_cursors(program, select, expr, cursor_hint).unwrap_or_default(); if cursors.is_empty() { HARDCODED_CURSOR_LEFT_TABLE } else {