diff --git a/COMPAT.md b/COMPAT.md index 0c3a92ac2..47265c521 100644 --- a/COMPAT.md +++ b/COMPAT.md @@ -45,7 +45,7 @@ This document describes the SQLite compatibility status of Limbo: | SELECT ... WHERE ... LIKE | Yes | | | SELECT ... LIMIT | Yes | | | SELECT ... ORDER BY | Partial | | -| SELECT ... GROUP BY | No | | +| SELECT ... GROUP BY | Partial | | | SELECT ... JOIN | Partial | | | SELECT ... CROSS JOIN | Partial | | | SELECT ... INNER JOIN | Partial | | diff --git a/core/lib.rs b/core/lib.rs index c0f42e574..164b299c8 100644 --- a/core/lib.rs +++ b/core/lib.rs @@ -198,7 +198,7 @@ impl Connection { match stmt { ast::Stmt::Select(select) => { let plan = prepare_select_plan(&self.schema, select)?; - let plan = optimize_plan(plan)?; + let (plan, _) = optimize_plan(plan)?; println!("{}", plan); } _ => todo!(), diff --git a/core/pseudo.rs b/core/pseudo.rs index 7f4d94e08..431fc83d1 100644 --- a/core/pseudo.rs +++ b/core/pseudo.rs @@ -41,7 +41,9 @@ impl Cursor for PseudoCursor { .as_ref() .map(|record| match record.values[0] { OwnedValue::Integer(rowid) => rowid as u64, - _ => panic!("Expected integer value"), + ref ov => { + panic!("Expected integer value, got {:?}", ov); + } }); Ok(x) } diff --git a/core/schema.rs b/core/schema.rs index fc3bf6603..f98348dd8 100644 --- a/core/schema.rs +++ b/core/schema.rs @@ -95,6 +95,13 @@ impl Table { } } + pub fn get_column_at(&self, index: usize) -> &Column { + match self { + Table::BTree(table) => table.columns.get(index).unwrap(), + Table::Pseudo(table) => table.columns.get(index).unwrap(), + } + } + pub fn columns(&self) -> &Vec { match self { Table::BTree(table) => &table.columns, diff --git a/core/translate/emitter.rs b/core/translate/emitter.rs index cac0aafe3..cba3f46c4 100644 --- a/core/translate/emitter.rs +++ b/core/translate/emitter.rs @@ -3,9 +3,12 @@ use std::collections::HashMap; use std::rc::Rc; use std::usize; +use sqlite3_parser::ast; + use crate::schema::{BTreeTable, Column, PseudoTable, Table}; use crate::storage::sqlite3_ondisk::DatabaseHeader; use crate::types::{OwnedRecord, OwnedValue}; +use crate::util::normalize_ident; use crate::vdbe::builder::ProgramBuilder; use crate::vdbe::{BranchOffset, Insn, Program}; use crate::Result; @@ -14,6 +17,7 @@ use super::expr::{ translate_aggregation, translate_condition_expr, translate_expr, translate_table_columns, ConditionMetadata, }; +use super::optimizer::ExpressionResultCache; use super::plan::Plan; use super::plan::{Operator, ProjectionColumn}; @@ -35,14 +39,14 @@ pub trait Emitter { program: &mut ProgramBuilder, referenced_tables: &[(Rc, String)], metadata: &mut Metadata, - cursor_override: Option, + cursor_override: Option<&SortCursorOverride>, ) -> Result; fn result_row( &mut self, program: &mut ProgramBuilder, referenced_tables: &[(Rc, String)], metadata: &mut Metadata, - cursor_override: Option, + cursor_override: Option<&SortCursorOverride>, ) -> Result<()>; } @@ -68,12 +72,56 @@ pub struct SortMetadata { pub sorter_data_label: BranchOffset, // label for the instruction immediately following SorterNext; SorterSort will jump here in case there is no data pub done_label: BranchOffset, + // register where the sorter data is inserted and later retrieved from + pub sorter_data_register: usize, } -#[derive(Debug, Default)] +#[derive(Debug)] +pub struct GroupByMetadata { + // Cursor ID for the Sorter table where the grouped rows are stored + pub sort_cursor: usize, + // Label for the subroutine that clears the accumulator registers (temporary storage for per-group aggregate calculations) + pub subroutine_accumulator_clear_label: BranchOffset, + // Register holding the return offset for the accumulator clear subroutine + pub subroutine_accumulator_clear_return_offset_register: usize, + // Label for the subroutine that outputs the accumulator contents + pub subroutine_accumulator_output_label: BranchOffset, + // Register holding the return offset for the accumulator output subroutine + pub subroutine_accumulator_output_return_offset_register: usize, + // Label for the instruction that sets the accumulator indicator to true (indicating data exists in the accumulator for the current group) + pub accumulator_indicator_set_true_label: BranchOffset, + // Label for the instruction where SorterData is emitted (used for fetching sorted data) + pub sorter_data_label: BranchOffset, + // Register holding the key used for sorting in the Sorter + pub sorter_key_register: usize, + // Label for the instruction signaling the completion of grouping operations + pub grouping_done_label: BranchOffset, + // Register holding a flag to abort the grouping process if necessary + pub abort_flag_register: usize, + // Register holding a boolean indicating whether there's data in the accumulator (used for aggregation) + pub data_in_accumulator_indicator_register: usize, + // Register holding the start of the accumulator group registers (i.e. the groups, not the aggregates) + pub group_exprs_accumulator_register: usize, + // Starting index of the register(s) that hold the comparison result between the current row and the previous row + // The comparison result is used to determine if the current row belongs to the same group as the previous row + // Each group by expression has a corresponding register + pub group_exprs_comparison_register: usize, +} + +#[derive(Debug)] +pub struct SortCursorOverride { + pub cursor_id: usize, + pub pseudo_table: Table, + pub sort_key_len: usize, +} + +/// The Metadata struct holds various information and labels used during bytecode generation. +/// It is used for maintaining state and control flow during the bytecode +/// generation process. +#[derive(Debug)] pub struct Metadata { // labels for the instructions that terminate the execution when a conditional check evaluates to false. typically jumps to Halt, but can also jump to AggFinal if a parent in the tree is an aggregation - termination_labels: Vec, + termination_label_stack: Vec, // labels for the instructions that jump to the next row in the current operator. // for example, in a join with two nested scans, the inner loop will jump to its Next instruction when the join condition is false; // in a join with a scan and a seek, the seek will jump to the scan's Next instruction when the join condition is false. @@ -82,28 +130,27 @@ pub struct Metadata { rewind_labels: Vec, // mapping between Aggregation operator id and the register that holds the start of the aggregation result aggregation_start_registers: HashMap, + // mapping between Aggregation operator id and associated metadata (if the aggregation has a group by clause) + group_bys: HashMap, // mapping between Order operator id and associated metadata sorts: HashMap, // mapping between Join operator id and associated metadata (for left joins only) left_joins: HashMap, + expr_result_cache: ExpressionResultCache, } -/** -* Emitters return one of three possible results from the step() method: -* - Continue: the operator is not yet ready to emit a result row -* - ReadyToEmit: the operator is ready to emit a result row -* - Done: the operator has completed execution -* For example, a Scan operator will return Continue until it has opened a cursor, rewound it and applied any predicates. -* At that point, it will return ReadyToEmit. -* Finally, when the Scan operator has emitted a Next instruction, it will return Done. -* -* Parent operators are free to make decisions based on the result a child operator's step() method. -* -* When the root operator of a Plan returns ReadyToEmit, a ResultRow will always be emitted. -* When the root operator returns Done, the bytecode plan is complete. -* - -*/ +/// Emitters return one of three possible results from the step() method: +/// - Continue: the operator is not yet ready to emit a result row +/// - ReadyToEmit: the operator is ready to emit a result row +/// - Done: the operator has completed execution +/// For example, a Scan operator will return Continue until it has opened a cursor, rewound it and applied any predicates. +/// At that point, it will return ReadyToEmit. +/// Finally, when the Scan operator has emitted a Next instruction, it will return Done. +/// +/// Parent operators are free to make decisions based on the result a child operator's step() method. +/// +/// When the root operator of a Plan returns ReadyToEmit, a ResultRow will always be emitted. +/// When the root operator returns Done, the bytecode plan is complete. #[derive(Debug, PartialEq)] pub enum OpStepResult { Continue, @@ -118,6 +165,7 @@ impl Emitter for Operator { m: &mut Metadata, referenced_tables: &[(Rc, String)], ) -> Result { + let current_operator_column_count = self.column_count(referenced_tables); match self { Operator::Scan { table, @@ -152,7 +200,7 @@ impl Emitter for Operator { let cursor_id = program.resolve_cursor_id(table_identifier, None); program.emit_insn(Insn::RewindAsync { cursor_id }); let rewind_label = program.allocate_label(); - let halt_label = m.termination_labels.last().unwrap(); + let halt_label = m.termination_label_stack.last().unwrap(); m.rewind_labels.push(rewind_label); program.defer_label_resolution(rewind_label, program.offset() as usize); program.emit_insn_with_label_dependency( @@ -239,11 +287,12 @@ impl Emitter for Operator { rowid_predicate, rowid_reg, None, + None, )?; let jump_label = m .next_row_labels .get(id) - .unwrap_or(&m.termination_labels.last().unwrap()); + .unwrap_or(&m.termination_label_stack.last().unwrap()); program.emit_insn_with_label_dependency( Insn::SeekRowid { cursor_id, @@ -312,7 +361,7 @@ impl Emitter for Operator { .next_row_labels .get(&right.id()) .or(m.next_row_labels.get(&left.id())) - .unwrap_or(&m.termination_labels.last().unwrap()); + .unwrap_or(&m.termination_label_stack.last().unwrap()); if *outer { let lj_meta = m.left_joins.get(id).unwrap(); @@ -408,15 +457,490 @@ impl Emitter for Operator { id, source, aggregates, + group_by, step, + .. } => { *step += 1; + + // Group by aggregation eg. SELECT a, b, sum(c) FROM t GROUP BY a, b + if let Some(group_by) = group_by { + const GROUP_BY_INIT: usize = 1; + const GROUP_BY_INSERT_INTO_SORTER: usize = 2; + const GROUP_BY_SORT_AND_COMPARE: usize = 3; + const GROUP_BY_PREPARE_ROW: usize = 4; + const GROUP_BY_CLEAR_ACCUMULATOR_SUBROUTINE: usize = 5; + match *step { + GROUP_BY_INIT => { + let agg_final_label = program.allocate_label(); + m.termination_label_stack.push(agg_final_label); + let num_aggs = aggregates.len(); + + let sort_cursor = program.alloc_cursor_id(None, None); + + let abort_flag_register = program.alloc_register(); + let data_in_accumulator_indicator_register = program.alloc_register(); + let group_exprs_comparison_register = + program.alloc_registers(group_by.len()); + let group_exprs_accumulator_register = + program.alloc_registers(group_by.len()); + let agg_exprs_start_reg = program.alloc_registers(num_aggs); + m.aggregation_start_registers + .insert(*id, agg_exprs_start_reg); + let sorter_key_register = program.alloc_register(); + + let subroutine_accumulator_clear_label = program.allocate_label(); + let subroutine_accumulator_output_label = program.allocate_label(); + let sorter_data_label = program.allocate_label(); + let grouping_done_label = program.allocate_label(); + + let mut order = Vec::new(); + const ASCENDING: i64 = 0; + for _ in group_by.iter() { + order.push(OwnedValue::Integer(ASCENDING as i64)); + } + program.emit_insn(Insn::SorterOpen { + cursor_id: sort_cursor, + columns: current_operator_column_count, + order: OwnedRecord::new(order), + }); + + program.add_comment(program.offset(), "clear group by abort flag"); + program.emit_insn(Insn::Integer { + value: 0, + dest: abort_flag_register, + }); + + program.add_comment( + program.offset(), + "initialize group by comparison registers to NULL", + ); + program.emit_insn(Insn::Null { + dest: group_exprs_comparison_register, + dest_end: if group_by.len() > 1 { + Some(group_exprs_comparison_register + group_by.len() - 1) + } else { + None + }, + }); + + program.add_comment( + program.offset(), + "go to clear accumulator subroutine", + ); + + let subroutine_accumulator_clear_return_offset_register = + program.alloc_register(); + program.emit_insn_with_label_dependency( + Insn::Gosub { + target_pc: subroutine_accumulator_clear_label, + return_reg: subroutine_accumulator_clear_return_offset_register, + }, + subroutine_accumulator_clear_label, + ); + + m.group_bys.insert( + *id, + GroupByMetadata { + sort_cursor, + subroutine_accumulator_clear_label, + subroutine_accumulator_clear_return_offset_register, + subroutine_accumulator_output_label, + subroutine_accumulator_output_return_offset_register: program + .alloc_register(), + accumulator_indicator_set_true_label: program.allocate_label(), + sorter_data_label, + grouping_done_label, + abort_flag_register, + data_in_accumulator_indicator_register, + group_exprs_accumulator_register, + group_exprs_comparison_register, + sorter_key_register, + }, + ); + + loop { + match source.step(program, m, referenced_tables)? { + OpStepResult::Continue => continue, + OpStepResult::ReadyToEmit => { + return Ok(OpStepResult::Continue); + } + OpStepResult::Done => { + return Ok(OpStepResult::Done); + } + } + } + } + GROUP_BY_INSERT_INTO_SORTER => { + let sort_keys_count = group_by.len(); + let start_reg = program.alloc_registers(current_operator_column_count); + for (i, expr) in group_by.iter().enumerate() { + let key_reg = start_reg + i; + translate_expr( + program, + Some(referenced_tables), + expr, + key_reg, + None, + None, + )?; + } + for (i, agg) in aggregates.iter().enumerate() { + let expr = &agg.args[0]; // TODO hakhackhachkachkachk hack hack + let agg_reg = start_reg + sort_keys_count + i; + translate_expr( + program, + Some(referenced_tables), + expr, + agg_reg, + None, + None, + )?; + } + + let group_by_metadata = m.group_bys.get(id).unwrap(); + + program.emit_insn(Insn::MakeRecord { + start_reg, + count: current_operator_column_count, + dest_reg: group_by_metadata.sorter_key_register, + }); + + let group_by_metadata = m.group_bys.get(id).unwrap(); + program.emit_insn(Insn::SorterInsert { + cursor_id: group_by_metadata.sort_cursor, + record_reg: group_by_metadata.sorter_key_register, + }); + + return Ok(OpStepResult::Continue); + } + GROUP_BY_SORT_AND_COMPARE => { + loop { + match source.step(program, m, referenced_tables)? { + OpStepResult::Done => { + break; + } + _ => unreachable!(), + } + } + + let group_by_metadata = m.group_bys.get_mut(id).unwrap(); + + let GroupByMetadata { + group_exprs_comparison_register: comparison_register, + subroutine_accumulator_output_return_offset_register, + subroutine_accumulator_output_label, + subroutine_accumulator_clear_return_offset_register, + subroutine_accumulator_clear_label, + data_in_accumulator_indicator_register, + accumulator_indicator_set_true_label, + group_exprs_accumulator_register: group_exprs_start_register, + abort_flag_register, + sorter_key_register, + .. + } = *group_by_metadata; + let halt_label = *m.termination_label_stack.first().unwrap(); + + let mut column_names = + Vec::with_capacity(current_operator_column_count); + for expr in group_by + .iter() + .chain(aggregates.iter().map(|agg| &agg.args[0])) + // FIXME: just blindly taking the first arg is a hack + { + // FIXME: reading from pseudo tables made during sort operations + // now relies on them having the same column names as the original + // table. This is not very robust IMO and we should refactor how these + // are handled. + column_names.push(match expr { + ast::Expr::Id(ident) => normalize_ident(&ident.0), + ast::Expr::Qualified(tbl, ident) => { + format!( + "{}.{}", + normalize_ident(&tbl.0), + normalize_ident(&ident.0) + ) + } + _ => "expr".to_string(), + }); + } + let pseudo_columns = column_names + .iter() + .map(|name| Column { + name: name.clone(), + primary_key: false, + ty: crate::schema::Type::Null, + }) + .collect::>(); + + let pseudo_cursor = program.alloc_cursor_id( + None, + Some(Table::Pseudo(Rc::new(PseudoTable { + columns: pseudo_columns, + }))), + ); + + program.emit_insn(Insn::OpenPseudo { + cursor_id: pseudo_cursor, + content_reg: sorter_key_register, + num_fields: current_operator_column_count, + }); + + let group_by_metadata = m.group_bys.get(id).unwrap(); + program.emit_insn_with_label_dependency( + Insn::SorterSort { + cursor_id: group_by_metadata.sort_cursor, + pc_if_empty: group_by_metadata.grouping_done_label, + }, + group_by_metadata.grouping_done_label, + ); + + program.defer_label_resolution( + group_by_metadata.sorter_data_label, + program.offset() as usize, + ); + program.emit_insn(Insn::SorterData { + cursor_id: group_by_metadata.sort_cursor, + dest_reg: group_by_metadata.sorter_key_register, + pseudo_cursor, + }); + + let groups_start_reg = program.alloc_registers(group_by.len()); + for (i, expr) in group_by.iter().enumerate() { + let group_reg = groups_start_reg + i; + translate_expr( + program, + Some(referenced_tables), + expr, + group_reg, + Some(pseudo_cursor), + None, + )?; + } + + program.emit_insn(Insn::Compare { + start_reg_a: comparison_register, + start_reg_b: groups_start_reg, + count: group_by.len(), + }); + + let agg_step_label = program.allocate_label(); + + program.add_comment( + program.offset(), + "start new group if comparison is not equal", + ); + program.emit_insn_with_label_dependency( + Insn::Jump { + target_pc_lt: program.offset() + 1, + target_pc_eq: agg_step_label, + target_pc_gt: program.offset() + 1, + }, + agg_step_label, + ); + + program.emit_insn(Insn::Move { + source_reg: groups_start_reg, + dest_reg: comparison_register, + count: group_by.len(), + }); + + program.add_comment( + program.offset(), + "check if ended group had data, and output if so", + ); + program.emit_insn_with_label_dependency( + Insn::Gosub { + target_pc: subroutine_accumulator_output_label, + return_reg: + subroutine_accumulator_output_return_offset_register, + }, + subroutine_accumulator_output_label, + ); + + program.add_comment(program.offset(), "check abort flag"); + program.emit_insn_with_label_dependency( + Insn::IfPos { + reg: abort_flag_register, + target_pc: halt_label, + decrement_by: 0, + }, + m.termination_label_stack[0], + ); + + program + .add_comment(program.offset(), "goto clear accumulator subroutine"); + program.emit_insn_with_label_dependency( + Insn::Gosub { + target_pc: subroutine_accumulator_clear_label, + return_reg: subroutine_accumulator_clear_return_offset_register, + }, + subroutine_accumulator_clear_label, + ); + + program.resolve_label(agg_step_label, program.offset()); + let start_reg = m.aggregation_start_registers.get(id).unwrap(); + for (i, agg) in aggregates.iter().enumerate() { + let agg_result_reg = start_reg + i; + translate_aggregation( + program, + referenced_tables, + agg, + agg_result_reg, + Some(pseudo_cursor), + )?; + } + + program.add_comment( + program.offset(), + "don't emit group columns if continuing existing group", + ); + program.emit_insn_with_label_dependency( + Insn::If { + target_pc: accumulator_indicator_set_true_label, + reg: data_in_accumulator_indicator_register, + null_reg: 0, // unused in this case + }, + accumulator_indicator_set_true_label, + ); + + for (i, expr) in group_by.iter().enumerate() { + let key_reg = group_exprs_start_register + i; + translate_expr( + program, + Some(referenced_tables), + expr, + key_reg, + Some(pseudo_cursor), + None, + )?; + } + + program.resolve_label( + accumulator_indicator_set_true_label, + program.offset(), + ); + program.add_comment(program.offset(), "indicate data in accumulator"); + program.emit_insn(Insn::Integer { + value: 1, + dest: data_in_accumulator_indicator_register, + }); + + return Ok(OpStepResult::Continue); + } + GROUP_BY_PREPARE_ROW => { + let group_by_metadata = m.group_bys.get(id).unwrap(); + program.emit_insn_with_label_dependency( + Insn::SorterNext { + cursor_id: group_by_metadata.sort_cursor, + pc_if_next: group_by_metadata.sorter_data_label, + }, + group_by_metadata.sorter_data_label, + ); + + program.resolve_label( + group_by_metadata.grouping_done_label, + program.offset(), + ); + + program.add_comment(program.offset(), "emit row for final group"); + program.emit_insn_with_label_dependency( + Insn::Gosub { + target_pc: group_by_metadata + .subroutine_accumulator_output_label, + return_reg: group_by_metadata + .subroutine_accumulator_output_return_offset_register, + }, + group_by_metadata.subroutine_accumulator_output_label, + ); + + program.add_comment(program.offset(), "group by finished"); + let termination_label = + m.termination_label_stack[m.termination_label_stack.len() - 2]; + program.emit_insn_with_label_dependency( + Insn::Goto { + target_pc: termination_label, + }, + termination_label, + ); + program.emit_insn(Insn::Integer { + value: 1, + dest: group_by_metadata.abort_flag_register, + }); + program.emit_insn(Insn::Return { + return_reg: group_by_metadata + .subroutine_accumulator_output_return_offset_register, + }); + + program.resolve_label( + group_by_metadata.subroutine_accumulator_output_label, + program.offset(), + ); + + program.add_comment( + program.offset(), + "output group by row subroutine start", + ); + let termination_label = *m.termination_label_stack.last().unwrap(); + program.emit_insn_with_label_dependency( + Insn::IfPos { + reg: group_by_metadata.data_in_accumulator_indicator_register, + target_pc: termination_label, + decrement_by: 0, + }, + termination_label, + ); + program.emit_insn(Insn::Return { + return_reg: group_by_metadata + .subroutine_accumulator_output_return_offset_register, + }); + + return Ok(OpStepResult::ReadyToEmit); + } + GROUP_BY_CLEAR_ACCUMULATOR_SUBROUTINE => { + let group_by_metadata = m.group_bys.get(id).unwrap(); + program.emit_insn(Insn::Return { + return_reg: group_by_metadata + .subroutine_accumulator_output_return_offset_register, + }); + + program.add_comment( + program.offset(), + "clear accumulator subroutine start", + ); + program.resolve_label( + group_by_metadata.subroutine_accumulator_clear_label, + program.offset(), + ); + let start_reg = group_by_metadata.group_exprs_accumulator_register; + program.emit_insn(Insn::Null { + dest: start_reg, + dest_end: Some(start_reg + group_by.len() + aggregates.len() - 1), + }); + + program.emit_insn(Insn::Integer { + value: 0, + dest: group_by_metadata.data_in_accumulator_indicator_register, + }); + program.emit_insn(Insn::Return { + return_reg: group_by_metadata + .subroutine_accumulator_clear_return_offset_register, + }); + } + _ => { + return Ok(OpStepResult::Done); + } + } + } + + // Non-grouped aggregation e.g. SELECT COUNT(*) FROM t + const AGGREGATE_INIT: usize = 1; const AGGREGATE_WAIT_UNTIL_SOURCE_READY: usize = 2; match *step { AGGREGATE_INIT => { let agg_final_label = program.allocate_label(); - m.termination_labels.push(agg_final_label); + m.termination_label_stack.push(agg_final_label); let num_aggs = aggregates.len(); let start_reg = program.alloc_registers(num_aggs); m.aggregation_start_registers.insert(*id, start_reg); @@ -473,12 +997,14 @@ impl Emitter for Operator { const ORDER_NEXT: usize = 4; match *step { ORDER_INIT => { + m.termination_label_stack.push(program.allocate_label()); let sort_cursor = program.alloc_cursor_id(None, None); m.sorts.insert( *id, SortMetadata { sort_cursor, pseudo_table_cursor: usize::MAX, // will be set later + sorter_data_register: program.alloc_register(), sorter_data_label: program.allocate_label(), done_label: program.allocate_label(), }, @@ -509,23 +1035,32 @@ impl Emitter for Operator { let sort_keys_count = key.len(); let source_cols_count = source.column_count(referenced_tables); let start_reg = program.alloc_registers(sort_keys_count); - for (i, (expr, _)) in key.iter().enumerate() { - let key_reg = start_reg + i; - translate_expr(program, Some(referenced_tables), expr, key_reg, None)?; - } source.result_columns(program, referenced_tables, m, None)?; - let dest = program.alloc_register(); + for (i, (expr, _)) in key.iter().enumerate() { + let key_reg = start_reg + i; + translate_expr( + program, + Some(referenced_tables), + expr, + key_reg, + None, + m.expr_result_cache + .get_cached_result_registers(*id, i) + .as_ref(), + )?; + } + + let sort_metadata = m.sorts.get_mut(id).unwrap(); program.emit_insn(Insn::MakeRecord { start_reg, count: sort_keys_count + source_cols_count, - dest_reg: dest, + dest_reg: sort_metadata.sorter_data_register, }); - let sort_metadata = m.sorts.get_mut(id).unwrap(); program.emit_insn(Insn::SorterInsert { cursor_id: sort_metadata.sort_cursor, - record_reg: dest, + record_reg: sort_metadata.sorter_data_register, }); Ok(OpStepResult::Continue) @@ -539,15 +1074,28 @@ impl Emitter for Operator { _ => unreachable!(), } } + program.resolve_label( + m.termination_label_stack.pop().unwrap(), + program.offset(), + ); let column_names = source.column_names(); - let pseudo_columns = column_names - .iter() - .map(|name| Column { + let mut pseudo_columns = vec![]; + for (i, _) in key.iter().enumerate() { + pseudo_columns.push(Column { + name: format!("sort_key_{}", i), + primary_key: false, + ty: crate::schema::Type::Null, + }); + } + for name in column_names { + pseudo_columns.push(Column { name: name.clone(), primary_key: false, ty: crate::schema::Type::Null, - }) - .collect::>(); + }); + } + + let num_fields = pseudo_columns.len(); let pseudo_cursor = program.alloc_cursor_id( None, @@ -555,15 +1103,14 @@ impl Emitter for Operator { columns: pseudo_columns, }))), ); + let sort_metadata = m.sorts.get(id).unwrap(); - let pseudo_content_reg = program.alloc_register(); program.emit_insn(Insn::OpenPseudo { cursor_id: pseudo_cursor, - content_reg: pseudo_content_reg, - num_fields: key.len() + source.column_count(referenced_tables), + content_reg: sort_metadata.sorter_data_register, + num_fields, }); - let sort_metadata = m.sorts.get(id).unwrap(); program.emit_insn_with_label_dependency( Insn::SorterSort { cursor_id: sort_metadata.sort_cursor, @@ -578,7 +1125,7 @@ impl Emitter for Operator { ); program.emit_insn(Insn::SorterData { cursor_id: sort_metadata.sort_cursor, - dest_reg: pseudo_content_reg, + dest_reg: sort_metadata.sorter_data_register, pseudo_cursor, }); @@ -614,6 +1161,9 @@ impl Emitter for Operator { match source.step(program, m, referenced_tables)? { OpStepResult::Continue => continue, OpStepResult::ReadyToEmit | OpStepResult::Done => { + if matches!(**source, Operator::Aggregate { .. }) { + source.result_columns(program, referenced_tables, m, None)?; + } return Ok(OpStepResult::ReadyToEmit); } } @@ -637,7 +1187,7 @@ impl Emitter for Operator { program: &mut ProgramBuilder, referenced_tables: &[(Rc, String)], m: &mut Metadata, - cursor_override: Option, + cursor_override: Option<&SortCursorOverride>, ) -> Result { let col_count = self.column_count(referenced_tables); match self { @@ -647,13 +1197,14 @@ impl Emitter for Operator { .. } => { let start_reg = program.alloc_registers(col_count); - translate_table_columns( - program, - table, - table_identifier, - cursor_override, - start_reg, - ); + let table = cursor_override + .map(|c| c.pseudo_table.clone()) + .unwrap_or_else(|| Table::BTree(table.clone())); + let cursor_id = cursor_override + .map(|c| c.cursor_id) + .unwrap_or_else(|| program.resolve_cursor_id(table_identifier, None)); + let start_column_offset = cursor_override.map(|c| c.sort_key_len).unwrap_or(0); + translate_table_columns(program, cursor_id, &table, start_column_offset, start_reg); Ok(start_reg) } @@ -664,17 +1215,57 @@ impl Emitter for Operator { Ok(left_start_reg) } - Operator::Aggregate { id, aggregates, .. } => { - let start_reg = m.aggregation_start_registers.get(id).unwrap(); + Operator::Aggregate { + id, + aggregates, + group_by, + .. + } => { + let agg_start_reg = m.aggregation_start_registers.get(id).unwrap(); + program.resolve_label(m.termination_label_stack.pop().unwrap(), program.offset()); + let mut result_column_idx = 0; for (i, agg) in aggregates.iter().enumerate() { - let agg_result_reg = *start_reg + i; + let agg_result_reg = *agg_start_reg + i; program.emit_insn(Insn::AggFinal { register: agg_result_reg, func: agg.func.clone(), }); + m.expr_result_cache.cache_result_register( + *id, + result_column_idx, + agg_result_reg, + agg.original_expr.clone(), + ); + result_column_idx += 1; } - Ok(*start_reg) + if let Some(group_by) = group_by { + let output_row_start_reg = + program.alloc_registers(aggregates.len() + group_by.len()); + let group_by_metadata = m.group_bys.get(id).unwrap(); + program.emit_insn(Insn::Copy { + src_reg: group_by_metadata.group_exprs_accumulator_register, + dst_reg: output_row_start_reg, + amount: group_by.len() - 1, + }); + for (i, source_expr) in group_by.iter().enumerate() { + m.expr_result_cache.cache_result_register( + *id, + result_column_idx + i, + output_row_start_reg + i, + source_expr.clone(), + ); + } + program.emit_insn(Insn::Copy { + src_reg: *agg_start_reg, + dst_reg: output_row_start_reg + group_by.len(), + amount: aggregates.len() - 1, + }); + + Ok(output_row_start_reg) + } else { + Ok(*agg_start_reg) + } } Operator::Filter { .. } => unreachable!("predicates have been pushed down"), Operator::SeekRowid { @@ -683,41 +1274,39 @@ impl Emitter for Operator { .. } => { let start_reg = program.alloc_registers(col_count); - translate_table_columns( - program, - table, - table_identifier, - cursor_override, - start_reg, - ); + let table = cursor_override + .map(|c| c.pseudo_table.clone()) + .unwrap_or_else(|| Table::BTree(table.clone())); + let cursor_id = cursor_override + .map(|c| c.cursor_id) + .unwrap_or_else(|| program.resolve_cursor_id(table_identifier, None)); + let start_column_offset = cursor_override.map(|c| c.sort_key_len).unwrap_or(0); + translate_table_columns(program, cursor_id, &table, start_column_offset, start_reg); Ok(start_reg) } Operator::Limit { .. } => { unimplemented!() } - Operator::Order { - id, source, key, .. - } => { - let sort_metadata = m.sorts.get(id).unwrap(); - let cursor_override = Some(sort_metadata.sort_cursor); - let sort_keys_count = key.len(); - let start_reg = program.alloc_registers(sort_keys_count); - for (i, (expr, _)) in key.iter().enumerate() { - let key_reg = start_reg + i; - translate_expr( - program, - Some(referenced_tables), - expr, - key_reg, - cursor_override, - )?; - } - source.result_columns(program, referenced_tables, m, cursor_override)?; + Operator::Order { id, key, .. } => { + let cursor_id = m.sorts.get(id).unwrap().pseudo_table_cursor; + let pseudo_table = program.resolve_cursor_to_table(cursor_id).unwrap(); + let start_column_offset = key.len(); + let column_count = pseudo_table.columns().len() - start_column_offset; + let start_reg = program.alloc_registers(column_count); + translate_table_columns( + program, + cursor_id, + &pseudo_table, + start_column_offset, + start_reg, + ); Ok(start_reg) } - Operator::Projection { expressions, .. } => { + Operator::Projection { + expressions, id, .. + } => { let expr_count = expressions .iter() .map(|e| e.column_count(referenced_tables)) @@ -732,17 +1321,35 @@ impl Emitter for Operator { Some(referenced_tables), expr, cur_reg, - cursor_override, + cursor_override.map(|c| c.cursor_id), + m.expr_result_cache + .get_cached_result_registers(*id, cur_reg - start_reg) + .as_ref(), )?; + m.expr_result_cache.cache_result_register( + *id, + cur_reg - start_reg, + cur_reg, + expr.clone(), + ); cur_reg += 1; } ProjectionColumn::Star => { for (table, table_identifier) in referenced_tables.iter() { + let table = cursor_override + .map(|c| c.pseudo_table.clone()) + .unwrap_or_else(|| Table::BTree(table.clone())); + let cursor_id = + cursor_override.map(|c| c.cursor_id).unwrap_or_else(|| { + program.resolve_cursor_id(table_identifier, None) + }); + let start_column_offset = + cursor_override.map(|c| c.sort_key_len).unwrap_or(0); cur_reg = translate_table_columns( program, - table, - table_identifier, - cursor_override, + cursor_id, + &table, + start_column_offset, cur_reg, ); } @@ -752,11 +1359,21 @@ impl Emitter for Operator { .iter() .find(|(_, id)| id == table_identifier) .unwrap(); + + let table = cursor_override + .map(|c| c.pseudo_table.clone()) + .unwrap_or_else(|| Table::BTree(table.clone())); + let cursor_id = + cursor_override.map(|c| c.cursor_id).unwrap_or_else(|| { + program.resolve_cursor_id(table_identifier, None) + }); + let start_column_offset = + cursor_override.map(|c| c.sort_key_len).unwrap_or(0); cur_reg = translate_table_columns( program, - table, - table_identifier, - cursor_override, + cursor_id, + &table, + start_column_offset, cur_reg, ); } @@ -773,20 +1390,9 @@ impl Emitter for Operator { program: &mut ProgramBuilder, referenced_tables: &[(Rc, String)], m: &mut Metadata, - cursor_override: Option, + cursor_override: Option<&SortCursorOverride>, ) -> Result<()> { match self { - Operator::Order { id, source, .. } => { - let sort_metadata = m.sorts.get(id).unwrap(); - source.result_row( - program, - referenced_tables, - m, - Some(sort_metadata.pseudo_table_cursor), - )?; - - Ok(()) - } Operator::Limit { source, limit, .. } => { source.result_row(program, referenced_tables, m, cursor_override)?; let limit_reg = program.alloc_register(); @@ -795,7 +1401,7 @@ impl Emitter for Operator { dest: limit_reg, }); program.mark_last_insn_constant(); - let jump_label = m.termination_labels.last().unwrap(); + let jump_label = m.termination_label_stack.first().unwrap(); program.emit_insn_with_label_dependency( Insn::DecrJumpZero { reg: limit_reg, @@ -819,7 +1425,9 @@ impl Emitter for Operator { } } -fn prologue() -> Result<( +fn prologue( + cache: ExpressionResultCache, +) -> Result<( ProgramBuilder, Metadata, BranchOffset, @@ -840,8 +1448,14 @@ fn prologue() -> Result<( let start_offset = program.offset(); let metadata = Metadata { - termination_labels: vec![halt_label], - ..Default::default() + termination_label_stack: vec![halt_label], + expr_result_cache: cache, + aggregation_start_registers: HashMap::new(), + group_bys: HashMap::new(), + left_joins: HashMap::new(), + next_row_labels: HashMap::new(), + rewind_labels: vec![], + sorts: HashMap::new(), }; Ok((program, metadata, init_label, halt_label, start_offset)) @@ -872,9 +1486,9 @@ fn epilogue( pub fn emit_program( database_header: Rc>, mut plan: Plan, + cache: ExpressionResultCache, ) -> Result { - let (mut program, mut metadata, init_label, halt_label, start_offset) = prologue()?; - + let (mut program, mut metadata, init_label, halt_label, start_offset) = prologue(cache)?; loop { match plan .root_operator diff --git a/core/translate/expr.rs b/core/translate/expr.rs index 6e18c8605..96c49f505 100644 --- a/core/translate/expr.rs +++ b/core/translate/expr.rs @@ -2,8 +2,9 @@ use crate::{function::JsonFunc, Result}; use sqlite3_parser::ast::{self, UnaryOperator}; use std::rc::Rc; +use super::optimizer::CachedResult; use crate::function::{AggFunc, Func, FuncCtx, ScalarFunc}; -use crate::schema::Type; +use crate::schema::{Table, Type}; use crate::util::normalize_ident; use crate::{ schema::BTreeTable, @@ -74,13 +75,27 @@ pub fn translate_condition_expr( } ast::Expr::Binary(lhs, op, rhs) => { let lhs_reg = program.alloc_register(); - let _ = translate_expr(program, Some(referenced_tables), lhs, lhs_reg, cursor_hint); + let _ = translate_expr( + program, + Some(referenced_tables), + lhs, + lhs_reg, + cursor_hint, + None, + ); match lhs.as_ref() { ast::Expr::Literal(_) => program.mark_last_insn_constant(), _ => {} } let rhs_reg = program.alloc_register(); - let _ = translate_expr(program, Some(referenced_tables), rhs, rhs_reg, cursor_hint); + let _ = translate_expr( + program, + Some(referenced_tables), + rhs, + rhs_reg, + cursor_hint, + None, + ); match rhs.as_ref() { ast::Expr::Literal(_) => program.mark_last_insn_constant(), _ => {} @@ -323,7 +338,14 @@ pub fn translate_condition_expr( // The left hand side only needs to be evaluated once we have a list of values to compare against. let lhs_reg = program.alloc_register(); - let _ = translate_expr(program, Some(referenced_tables), lhs, lhs_reg, cursor_hint)?; + let _ = translate_expr( + program, + Some(referenced_tables), + lhs, + lhs_reg, + cursor_hint, + None, + )?; let rhs = rhs.as_ref().unwrap(); @@ -352,6 +374,7 @@ pub fn translate_condition_expr( expr, rhs_reg, cursor_hint, + None, )?; // If this is not the last condition, we need to jump to the 'jump_target_when_true' label if the condition is true. if !last_condition { @@ -395,6 +418,7 @@ pub fn translate_condition_expr( expr, rhs_reg, cursor_hint, + None, )?; program.emit_insn_with_label_dependency( Insn::Eq { @@ -444,6 +468,7 @@ pub fn translate_condition_expr( rhs, pattern_reg, cursor_hint, + None, )?; program.mark_last_insn_constant(); let _ = translate_expr( @@ -452,6 +477,7 @@ pub fn translate_condition_expr( lhs, column_reg, cursor_hint, + None, )?; program.emit_insn(Insn::Function { // Only constant patterns for LIKE are supported currently, so this @@ -516,20 +542,72 @@ pub fn translate_condition_expr( Ok(()) } +pub fn get_cached_or_translate( + program: &mut ProgramBuilder, + referenced_tables: Option<&[(Rc, String)]>, + expr: &ast::Expr, + cursor_hint: Option, + cached_results: Option<&Vec<&CachedResult>>, +) -> Result { + if let Some(cached_results) = cached_results { + if let Some(cached_result) = cached_results + .iter() + .find(|cached_result| cached_result.source_expr == *expr) + { + return Ok(cached_result.register_idx); + } + } + let reg = program.alloc_register(); + translate_expr( + program, + referenced_tables, + expr, + reg, + cursor_hint, + cached_results, + )?; + Ok(reg) +} + pub fn translate_expr( program: &mut ProgramBuilder, referenced_tables: Option<&[(Rc, String)]>, expr: &ast::Expr, target_register: usize, cursor_hint: Option, + cached_results: Option<&Vec<&CachedResult>>, ) -> Result { + if let Some(cached_results) = &cached_results { + if let Some(cached_result) = cached_results + .iter() + .find(|cached_result| cached_result.source_expr == *expr) + { + program.emit_insn(Insn::Copy { + src_reg: cached_result.register_idx, + dst_reg: target_register, + amount: 0, + }); + return Ok(target_register); + } + } + match expr { ast::Expr::Between { .. } => todo!(), ast::Expr::Binary(e1, op, e2) => { - let e1_reg = program.alloc_register(); - let _ = translate_expr(program, referenced_tables, e1, e1_reg, cursor_hint)?; - let e2_reg = program.alloc_register(); - let _ = translate_expr(program, referenced_tables, e2, e2_reg, cursor_hint)?; + let e1_reg = get_cached_or_translate( + program, + referenced_tables, + e1, + cursor_hint, + cached_results, + )?; + let e2_reg = get_cached_or_translate( + program, + referenced_tables, + e2, + cursor_hint, + cached_results, + )?; match op { ast::Operator::NotEquals => { @@ -617,6 +695,13 @@ pub fn translate_expr( dest: target_register, }); } + ast::Operator::Multiply => { + program.emit_insn(Insn::Multiply { + lhs: e1_reg, + rhs: e2_reg, + dest: target_register, + }); + } other_unimplemented => todo!("{:?}", other_unimplemented), } Ok(target_register) @@ -667,7 +752,14 @@ pub fn translate_expr( ); }; let regs = program.alloc_register(); - translate_expr(program, referenced_tables, &args[0], regs, cursor_hint)?; + translate_expr( + program, + referenced_tables, + &args[0], + regs, + cursor_hint, + cached_results, + )?; program.emit_insn(Insn::Function { constant_mask: 0, start_reg: regs, @@ -684,7 +776,14 @@ pub fn translate_expr( for arg in args.iter() { let reg = program.alloc_register(); - translate_expr(program, referenced_tables, arg, reg, cursor_hint)?; + translate_expr( + program, + referenced_tables, + arg, + reg, + cursor_hint, + cached_results, + )?; } program.emit_insn(Insn::Function { @@ -721,6 +820,7 @@ pub fn translate_expr( arg, target_register, cursor_hint, + cached_results, )?; if index < args.len() - 1 { program.emit_insn_with_label_dependency( @@ -747,7 +847,14 @@ pub fn translate_expr( }; for arg in args.iter() { let reg = program.alloc_register(); - translate_expr(program, referenced_tables, arg, reg, cursor_hint)?; + translate_expr( + program, + referenced_tables, + arg, + reg, + cursor_hint, + cached_results, + )?; } program.emit_insn(Insn::Function { constant_mask: 0, @@ -777,6 +884,7 @@ pub fn translate_expr( &args[0], temp_reg, cursor_hint, + cached_results, )?; program.emit_insn(Insn::NotNull { reg: temp_reg, @@ -789,6 +897,7 @@ pub fn translate_expr( &args[1], temp_reg, cursor_hint, + cached_results, )?; program.emit_insn(Insn::Copy { src_reg: temp_reg, @@ -821,6 +930,7 @@ pub fn translate_expr( arg, reg, cursor_hint, + cached_results, )?; match arg { ast::Expr::Literal(_) => program.mark_last_insn_constant(), @@ -865,6 +975,7 @@ pub fn translate_expr( &args[0], regs, cursor_hint, + cached_results, )?; program.emit_insn(Insn::Function { constant_mask: 0, @@ -901,6 +1012,7 @@ pub fn translate_expr( arg, target_reg, cursor_hint, + cached_results, )?; } } @@ -938,6 +1050,7 @@ pub fn translate_expr( &args[0], str_reg, cursor_hint, + cached_results, )?; translate_expr( program, @@ -945,6 +1058,7 @@ pub fn translate_expr( &args[1], start_reg, cursor_hint, + cached_results, )?; if args.len() == 3 { translate_expr( @@ -953,6 +1067,7 @@ pub fn translate_expr( &args[2], length_reg, cursor_hint, + cached_results, )?; } @@ -977,6 +1092,7 @@ pub fn translate_expr( &args[0], arg_reg, cursor_hint, + cached_results, )?; start_reg = arg_reg; } @@ -1000,6 +1116,7 @@ pub fn translate_expr( arg, target_reg, cursor_hint, + cached_results, )?; } } @@ -1032,7 +1149,14 @@ pub fn translate_expr( for arg in args.iter() { let reg = program.alloc_register(); - translate_expr(program, referenced_tables, arg, reg, cursor_hint)?; + translate_expr( + program, + referenced_tables, + arg, + reg, + cursor_hint, + cached_results, + )?; if let ast::Expr::Literal(_) = arg { program.mark_last_insn_constant(); } @@ -1064,6 +1188,7 @@ pub fn translate_expr( arg, reg, cursor_hint, + cached_results, )?; match arg { ast::Expr::Literal(_) => program.mark_last_insn_constant(), @@ -1098,6 +1223,7 @@ pub fn translate_expr( arg, reg, cursor_hint, + cached_results, )?; match arg { ast::Expr::Literal(_) => program.mark_last_insn_constant(), @@ -1132,6 +1258,7 @@ pub fn translate_expr( &args[0], first_reg, cursor_hint, + cached_results, )?; let second_reg = program.alloc_register(); translate_expr( @@ -1140,6 +1267,7 @@ pub fn translate_expr( &args[1], second_reg, cursor_hint, + cached_results, )?; program.emit_insn(Insn::Function { constant_mask: 0, @@ -1208,6 +1336,7 @@ pub fn translate_expr( ast::Literal::Null => { program.emit_insn(Insn::Null { dest: target_register, + dest_end: None, }); Ok(target_register) } @@ -1389,16 +1518,15 @@ pub fn maybe_apply_affinity(col_type: Type, target_register: usize, program: &mu pub fn translate_table_columns( program: &mut ProgramBuilder, - table: &Rc, - table_identifier: &str, - cursor_override: Option, + cursor_id: usize, + table: &Table, + start_column_offset: usize, start_reg: usize, ) -> usize { let mut cur_reg = start_reg; - let cursor_id = cursor_override.unwrap_or(program.resolve_cursor_id(table_identifier, None)); - for i in 0..table.columns.len() { - let is_rowid = table.column_is_rowid_alias(&table.columns[i]); - let col_type = &table.columns[i].ty; + for i in start_column_offset..table.columns().len() { + let is_rowid = table.column_is_rowid_alias(&table.get_column_at(i)); + let col_type = &table.get_column_at(i).ty; if is_rowid { program.emit_insn(Insn::RowId { cursor_id, @@ -1437,6 +1565,7 @@ pub fn translate_aggregation( expr, expr_reg, cursor_hint, + None, )?; program.emit_insn(Insn::AggStep { acc_reg: target_register, @@ -1458,6 +1587,7 @@ pub fn translate_aggregation( expr, expr_reg, cursor_hint, + None, ); expr_reg }; @@ -1505,6 +1635,7 @@ pub fn translate_aggregation( expr, expr_reg, cursor_hint, + None, )?; translate_expr( program, @@ -1512,6 +1643,7 @@ pub fn translate_aggregation( &delimiter_expr, delimiter_reg, cursor_hint, + None, )?; program.emit_insn(Insn::AggStep { @@ -1535,6 +1667,7 @@ pub fn translate_aggregation( expr, expr_reg, cursor_hint, + None, )?; program.emit_insn(Insn::AggStep { acc_reg: target_register, @@ -1556,6 +1689,7 @@ pub fn translate_aggregation( expr, expr_reg, cursor_hint, + None, )?; program.emit_insn(Insn::AggStep { acc_reg: target_register, @@ -1596,6 +1730,7 @@ pub fn translate_aggregation( expr, expr_reg, cursor_hint, + None, )?; translate_expr( program, @@ -1603,6 +1738,7 @@ pub fn translate_aggregation( &delimiter_expr, delimiter_reg, cursor_hint, + None, )?; program.emit_insn(Insn::AggStep { @@ -1626,6 +1762,7 @@ pub fn translate_aggregation( expr, expr_reg, cursor_hint, + None, )?; program.emit_insn(Insn::AggStep { acc_reg: target_register, @@ -1647,6 +1784,7 @@ pub fn translate_aggregation( expr, expr_reg, cursor_hint, + None, )?; program.emit_insn(Insn::AggStep { acc_reg: target_register, diff --git a/core/translate/insert.rs b/core/translate/insert.rs index e3b2ae7e9..7e04f4901 100644 --- a/core/translate/insert.rs +++ b/core/translate/insert.rs @@ -93,6 +93,7 @@ pub fn translate_insert( expr, column_registers_start + col, None, + None, )?; } program.emit_insn(Insn::Yield { diff --git a/core/translate/optimizer.rs b/core/translate/optimizer.rs index 7b9d68b01..01ef358a4 100644 --- a/core/translate/optimizer.rs +++ b/core/translate/optimizer.rs @@ -1,4 +1,4 @@ -use std::rc::Rc; +use std::{collections::HashMap, rc::Rc}; use sqlite3_parser::ast; @@ -6,12 +6,14 @@ use crate::{schema::BTreeTable, util::normalize_ident, Result}; use super::plan::{ get_table_ref_bitmask_for_ast_expr, get_table_ref_bitmask_for_operator, Operator, Plan, + ProjectionColumn, }; /** * Make a few passes over the plan to optimize it. */ -pub fn optimize_plan(mut select_plan: Plan) -> Result { +pub fn optimize_plan(mut select_plan: Plan) -> Result<(Plan, ExpressionResultCache)> { + let mut expr_result_cache = ExpressionResultCache::new(); push_predicates( &mut select_plan.root_operator, &select_plan.referenced_tables, @@ -19,16 +21,20 @@ pub fn optimize_plan(mut select_plan: Plan) -> Result { if eliminate_constants(&mut select_plan.root_operator)? == ConstantConditionEliminationResult::ImpossibleCondition { - return Ok(Plan { - root_operator: Operator::Nothing, - referenced_tables: vec![], - }); + return Ok(( + Plan { + root_operator: Operator::Nothing, + referenced_tables: vec![], + }, + expr_result_cache, + )); } use_indexes( &mut select_plan.root_operator, &select_plan.referenced_tables, )?; - Ok(select_plan) + find_shared_expressions_in_child_operators_and_mark_them_so_that_the_parent_operator_doesnt_recompute_them(&select_plan.root_operator, &mut expr_result_cache); + Ok((select_plan, expr_result_cache)) } /** @@ -523,6 +529,383 @@ fn push_predicate( } } +#[derive(Debug)] +pub struct ExpressionResultCache { + resultmap: HashMap, + keymap: HashMap>, +} + +#[derive(Debug)] +pub struct CachedResult { + pub register_idx: usize, + pub source_expr: ast::Expr, +} + +const OPERATOR_ID_MULTIPLIER: usize = 10000; + +/** + ExpressionResultCache is a cache for the results of expressions that are computed in the query plan, + or more precisely, the VM registers that hold the results of these expressions. + + Right now the cache is mainly used to avoid recomputing e.g. the result of an aggregation expression + e.g. SELECT t.a, SUM(t.b) FROM t GROUP BY t.a ORDER BY SUM(t.b) +*/ +impl ExpressionResultCache { + pub fn new() -> Self { + ExpressionResultCache { + resultmap: HashMap::new(), + keymap: HashMap::new(), + } + } + + /** + Store the result of an expression that is computed in the query plan. + The result is stored in a VM register. A copy of the expression AST node is + stored as well, so that parent operators can use it to compare their own expressions + with the one that was computed in a child operator. + + This is a weakness of our current reliance on a 3rd party AST library, as we can't + e.g. modify the AST to add identifiers to nodes or replace nodes with some kind of + reference to a register, etc. + */ + pub fn cache_result_register( + &mut self, + operator_id: usize, + result_column_idx: usize, + register_idx: usize, + expr: ast::Expr, + ) { + let key = operator_id * OPERATOR_ID_MULTIPLIER + result_column_idx; + self.resultmap.insert( + key, + CachedResult { + register_idx, + source_expr: expr, + }, + ); + } + + /** + Set a mapping from a parent operator to a child operator, so that the parent operator + can look up the register of a result that was computed in the child operator. + E.g. "Parent operator's result column 3 is computed in child operator 5, result column 2" + */ + pub fn set_precomputation_key( + &mut self, + operator_id: usize, + result_column_idx: usize, + child_operator_id: usize, + child_operator_result_column_idx_mask: usize, + ) -> () { + let key = operator_id * OPERATOR_ID_MULTIPLIER + result_column_idx; + + let mut values = Vec::new(); + for i in 0..64 { + if (child_operator_result_column_idx_mask >> i) & 1 == 1 { + values.push(child_operator_id * OPERATOR_ID_MULTIPLIER + i); + } + } + self.keymap.insert(key, values); + } + + /** + Get the cache entries for a given operator and result column index. + There may be multiple cached entries, e.g. a binary operator's both + arms may have been cached. + */ + pub fn get_cached_result_registers( + &self, + operator_id: usize, + result_column_idx: usize, + ) -> Option> { + let key = operator_id * OPERATOR_ID_MULTIPLIER + result_column_idx; + self.keymap.get(&key).and_then(|keys| { + let mut results = Vec::new(); + for key in keys { + if let Some(result) = self.resultmap.get(key) { + results.push(result); + } + } + if results.is_empty() { + None + } else { + Some(results) + } + }) + } +} + +type ResultColumnIndexBitmask = usize; + +/** + Find all result columns in an operator that match an expression, either fully or partially. + This is used to find the result columns that are computed in an operator and that are used + in a parent operator, so that the parent operator can look up the register that holds the result + of the child operator's expression. + + The result is returned as a bitmask due to performance neuroticism. A limitation of this is that + we can only handle 64 result columns per operator. +*/ +fn find_indexes_of_all_result_columns_in_operator_that_match_expr_either_fully_or_partially( + expr: &ast::Expr, + operator: &Operator, +) -> ResultColumnIndexBitmask { + let exact_match = match operator { + Operator::Aggregate { + aggregates, + group_by, + .. + } => { + let mut idx = 0; + let mut mask = 0; + for agg in aggregates.iter() { + if agg.original_expr == *expr { + mask |= 1 << idx; + } + idx += 1; + } + + if let Some(group_by) = group_by { + for g in group_by.iter() { + if g == expr { + mask |= 1 << idx; + } + idx += 1 + } + } + + mask + } + Operator::Filter { .. } => 0, + Operator::SeekRowid { .. } => 0, + Operator::Limit { .. } => 0, + Operator::Join { .. } => 0, + Operator::Order { .. } => 0, + Operator::Projection { expressions, .. } => { + let mut idx = 0; + let mut mask = 0; + for e in expressions.iter() { + match e { + super::plan::ProjectionColumn::Column(c) => { + if c == expr { + mask |= 1 << idx; + } + } + super::plan::ProjectionColumn::Star => {} + super::plan::ProjectionColumn::TableStar(_, _) => {} + } + idx += 1; + } + + mask + } + Operator::Scan { .. } => 0, + Operator::Nothing => 0, + }; + + if exact_match != 0 { + return exact_match; + } + + match expr { + ast::Expr::Between { + lhs, + not, + start, + end, + } => { + let mut mask = 0; + mask |= find_indexes_of_all_result_columns_in_operator_that_match_expr_either_fully_or_partially(lhs, operator); + mask |= find_indexes_of_all_result_columns_in_operator_that_match_expr_either_fully_or_partially(start, operator); + mask |= find_indexes_of_all_result_columns_in_operator_that_match_expr_either_fully_or_partially(end, operator); + mask + } + ast::Expr::Binary(lhs, op, rhs) => { + let mut mask = 0; + mask |= find_indexes_of_all_result_columns_in_operator_that_match_expr_either_fully_or_partially(lhs, operator); + mask |= find_indexes_of_all_result_columns_in_operator_that_match_expr_either_fully_or_partially(rhs, operator); + mask + } + ast::Expr::Case { + base, + when_then_pairs, + else_expr, + } => { + let mut mask = 0; + if let Some(base) = base { + mask |= find_indexes_of_all_result_columns_in_operator_that_match_expr_either_fully_or_partially(base, operator); + } + for (w, t) in when_then_pairs.iter() { + mask |= find_indexes_of_all_result_columns_in_operator_that_match_expr_either_fully_or_partially(w, operator); + mask |= find_indexes_of_all_result_columns_in_operator_that_match_expr_either_fully_or_partially(t, operator); + } + if let Some(e) = else_expr { + mask |= find_indexes_of_all_result_columns_in_operator_that_match_expr_either_fully_or_partially(e, operator); + } + mask + } + ast::Expr::Cast { expr, type_name } => { + find_indexes_of_all_result_columns_in_operator_that_match_expr_either_fully_or_partially( + expr, operator, + ) + } + ast::Expr::Collate(expr, collation) => { + find_indexes_of_all_result_columns_in_operator_that_match_expr_either_fully_or_partially( + expr, operator, + ) + } + ast::Expr::DoublyQualified(schema, tbl, ident) => 0, + ast::Expr::Exists(_) => 0, + ast::Expr::FunctionCall { + name, + distinctness, + args, + order_by, + filter_over, + } => { + let mut mask = 0; + if let Some(args) = args { + for a in args.iter() { + mask |= find_indexes_of_all_result_columns_in_operator_that_match_expr_either_fully_or_partially(a, operator); + } + } + mask + } + ast::Expr::FunctionCallStar { name, filter_over } => 0, + ast::Expr::Id(_) => 0, + ast::Expr::InList { lhs, not, rhs } => { + let mut mask = 0; + mask |= find_indexes_of_all_result_columns_in_operator_that_match_expr_either_fully_or_partially(lhs, operator); + if let Some(rhs) = rhs { + for r in rhs.iter() { + mask |= find_indexes_of_all_result_columns_in_operator_that_match_expr_either_fully_or_partially(r, operator); + } + } + mask + } + ast::Expr::InSelect { lhs, not, rhs } => { + find_indexes_of_all_result_columns_in_operator_that_match_expr_either_fully_or_partially( + lhs, operator, + ) + } + ast::Expr::InTable { + lhs, + not, + rhs, + args, + } => 0, + ast::Expr::IsNull(expr) => { + find_indexes_of_all_result_columns_in_operator_that_match_expr_either_fully_or_partially( + expr, operator, + ) + } + ast::Expr::Like { + lhs, + not, + op, + rhs, + escape, + } => { + let mut mask = 0; + mask |= find_indexes_of_all_result_columns_in_operator_that_match_expr_either_fully_or_partially(lhs, operator); + mask |= find_indexes_of_all_result_columns_in_operator_that_match_expr_either_fully_or_partially(rhs, operator); + mask + } + ast::Expr::Literal(_) => 0, + ast::Expr::Name(_) => 0, + ast::Expr::NotNull(expr) => { + find_indexes_of_all_result_columns_in_operator_that_match_expr_either_fully_or_partially( + expr, operator, + ) + } + ast::Expr::Parenthesized(expr) => { + let mut mask = 0; + for e in expr.iter() { + mask |= find_indexes_of_all_result_columns_in_operator_that_match_expr_either_fully_or_partially(e, operator); + } + mask + } + ast::Expr::Qualified(_, _) => 0, + ast::Expr::Raise(_, _) => 0, + ast::Expr::Subquery(_) => 0, + ast::Expr::Unary(op, expr) => { + find_indexes_of_all_result_columns_in_operator_that_match_expr_either_fully_or_partially( + expr, operator, + ) + } + ast::Expr::Variable(_) => 0, + } +} + +/** + * This function is used to find all the expressions that are shared between the parent operator and the child operators. + * If an expression is shared between the parent and child operators, then the parent operator should not recompute the expression. + * Instead, it should use the result of the expression that was computed by the child operator. +*/ +fn find_shared_expressions_in_child_operators_and_mark_them_so_that_the_parent_operator_doesnt_recompute_them( + operator: &Operator, + expr_result_cache: &mut ExpressionResultCache, +) { + match operator { + Operator::Aggregate { + source, + aggregates, + group_by, + .. + } => { + find_shared_expressions_in_child_operators_and_mark_them_so_that_the_parent_operator_doesnt_recompute_them( + source, expr_result_cache, + ) + } + Operator::Filter { .. } => unreachable!(), + Operator::SeekRowid { .. } => {} + Operator::Limit { source, .. } => { + find_shared_expressions_in_child_operators_and_mark_them_so_that_the_parent_operator_doesnt_recompute_them(source, expr_result_cache) + } + Operator::Join { .. } => {} + Operator::Order { source, key, .. } => { + let mut idx = 0; + + for (expr, _) in key.iter() { + let result = find_indexes_of_all_result_columns_in_operator_that_match_expr_either_fully_or_partially(&expr, source); + if result != 0 { + expr_result_cache.set_precomputation_key( + operator.id(), + idx, + source.id(), + result, + ); + } + idx += 1; + } + find_shared_expressions_in_child_operators_and_mark_them_so_that_the_parent_operator_doesnt_recompute_them(source, expr_result_cache) + } + Operator::Projection { source, expressions, .. } => { + let mut idx = 0; + for expr in expressions.iter() { + match expr { + ProjectionColumn::Column(expr) => { + let result = find_indexes_of_all_result_columns_in_operator_that_match_expr_either_fully_or_partially(&expr, source); + if result != 0 { + expr_result_cache.set_precomputation_key( + operator.id(), + idx, + source.id(), + result, + ); + } + } + _ => {} + } + idx += 1; + } + find_shared_expressions_in_child_operators_and_mark_them_so_that_the_parent_operator_doesnt_recompute_them(source, expr_result_cache) + } + Operator::Scan { .. } => {} + Operator::Nothing => {} + } +} + #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum ConstantPredicate { AlwaysTrue, @@ -762,7 +1145,3 @@ impl TakeOwnership for Operator { std::mem::replace(self, Operator::Nothing) } } - -fn replace_with(expr: &mut T, mut replacement: T) { - *expr = replacement.take_ownership(); -} diff --git a/core/translate/plan.rs b/core/translate/plan.rs index eb9f321ca..6a73e4b57 100644 --- a/core/translate/plan.rs +++ b/core/translate/plan.rs @@ -43,6 +43,7 @@ pub enum Operator { id: usize, source: Box, aggregates: Vec, + group_by: Option>, step: usize, }, // Filter operator @@ -154,7 +155,11 @@ impl ProjectionColumn { impl Operator { pub fn column_count(&self, referenced_tables: &[(Rc, String)]) -> usize { match self { - Operator::Aggregate { aggregates, .. } => aggregates.len(), + Operator::Aggregate { + group_by, + aggregates, + .. + } => aggregates.len() + group_by.as_ref().map_or(0, |g| g.len()), Operator::Filter { source, .. } => source.column_count(referenced_tables), Operator::SeekRowid { table, .. } => table.columns.len(), Operator::Limit { source, .. } => source.column_count(referenced_tables), @@ -173,8 +178,29 @@ impl Operator { pub fn column_names(&self) -> Vec { match self { - Operator::Aggregate { .. } => { - todo!(); + Operator::Aggregate { + aggregates, + group_by, + .. + } => { + let mut names = vec![]; + for agg in aggregates.iter() { + names.push(agg.func.to_string().to_string()); + } + + if let Some(group_by) = group_by { + for expr in group_by.iter() { + match expr { + ast::Expr::Id(ident) => names.push(ident.0.clone()), + ast::Expr::Qualified(tbl, ident) => { + names.push(format!("{}.{}", tbl.0, ident.0)) + } + e => names.push(e.to_string()), + } + } + } + + names } Operator::Filter { source, .. } => source.column_names(), Operator::SeekRowid { table, .. } => { @@ -238,6 +264,7 @@ impl Display for Direction { pub struct Aggregate { pub func: AggFunc, pub args: Vec, + pub original_expr: ast::Expr, } impl Display for Aggregate { diff --git a/core/translate/planner.rs b/core/translate/planner.rs index 94ca386b9..f953da554 100644 --- a/core/translate/planner.rs +++ b/core/translate/planner.rs @@ -14,7 +14,7 @@ pub struct OperatorIdCounter { impl OperatorIdCounter { pub fn new() -> Self { - Self { id: 0 } + Self { id: 1 } } pub fn get_next_id(&mut self) -> usize { let id = self.id; @@ -23,12 +23,54 @@ impl OperatorIdCounter { } } +fn resolve_aggregates(expr: &ast::Expr, aggs: &mut Vec) { + match expr { + ast::Expr::FunctionCall { name, args, .. } => { + let args_count = if let Some(args) = &args { + args.len() + } else { + 0 + }; + match Func::resolve_function(normalize_ident(name.0.as_str()).as_str(), args_count) { + Ok(Func::Agg(f)) => aggs.push(Aggregate { + func: f, + args: args.clone().unwrap_or_default(), + original_expr: expr.clone(), + }), + _ => { + if let Some(args) = args { + for arg in args.iter() { + resolve_aggregates(&arg, aggs); + } + } + } + } + } + ast::Expr::FunctionCallStar { name, .. } => { + match Func::resolve_function(normalize_ident(name.0.as_str()).as_str(), 0) { + Ok(Func::Agg(f)) => aggs.push(Aggregate { + func: f, + args: vec![], + original_expr: expr.clone(), + }), + _ => {} + } + } + ast::Expr::Binary(lhs, _, rhs) => { + resolve_aggregates(lhs, aggs); + resolve_aggregates(rhs, aggs); + } + _ => {} + } +} + pub fn prepare_select_plan<'a>(schema: &Schema, select: ast::Select) -> Result { match select.body.select { ast::OneSelect::Select { columns, from, where_clause, + group_by, .. } => { let col_count = columns.len(); @@ -53,21 +95,17 @@ pub fn prepare_select_plan<'a>(schema: &Schema, select: ast::Select) -> Result

{ - scalar_expressions.push(ProjectionColumn::Star); + projection_expressions.push(ProjectionColumn::Star); } ast::ResultColumn::TableStar(name) => { let name_normalized = normalize_ident(name.0.as_str()); @@ -79,89 +117,98 @@ pub fn prepare_select_plan<'a>(schema: &Schema, select: ast::Select) -> Result

match expr { - ast::Expr::FunctionCall { - name, - distinctness, - args, - filter_over, - order_by, - } => { - let args_count = if let Some(args) = &args { - args.len() - } else { - 0 - }; - match Func::resolve_function( - normalize_ident(name.0.as_str()).as_str(), - args_count, - ) { - Ok(Func::Agg(f)) => aggregate_expressions.push(Aggregate { - func: f, - args: args.unwrap(), - }), - Ok(_) => { - scalar_expressions.push(ProjectionColumn::Column( - ast::Expr::FunctionCall { - name, - distinctness, - args, - filter_over, - order_by, - }, - )); + ast::ResultColumn::Expr(expr, _) => { + projection_expressions.push(ProjectionColumn::Column(expr.clone())); + match expr.clone() { + ast::Expr::FunctionCall { + name, + distinctness, + args, + filter_over, + order_by, + } => { + let args_count = if let Some(args) = &args { + args.len() + } else { + 0 + }; + match Func::resolve_function( + normalize_ident(name.0.as_str()).as_str(), + args_count, + ) { + Ok(Func::Agg(f)) => { + aggregate_expressions.push(Aggregate { + func: f, + args: args.unwrap(), + original_expr: expr.clone(), + }); + } + Ok(_) => { + resolve_aggregates(&expr, &mut aggregate_expressions); + } + _ => {} } - _ => {} } - } - ast::Expr::FunctionCallStar { name, filter_over } => { - match Func::resolve_function( - normalize_ident(name.0.as_str()).as_str(), - 0, - ) { - Ok(Func::Agg(f)) => aggregate_expressions.push(Aggregate { - func: f, - args: vec![], - }), - Ok(Func::Scalar(_)) => { - scalar_expressions.push(ProjectionColumn::Column( - ast::Expr::FunctionCallStar { name, filter_over }, - )); + ast::Expr::FunctionCallStar { name, filter_over } => { + match Func::resolve_function( + normalize_ident(name.0.as_str()).as_str(), + 0, + ) { + Ok(Func::Agg(f)) => { + aggregate_expressions.push(Aggregate { + func: f, + args: vec![], + original_expr: expr.clone(), + }); + } + _ => {} } - _ => {} } + ast::Expr::Binary(lhs, _, rhs) => { + resolve_aggregates(&lhs, &mut aggregate_expressions); + resolve_aggregates(&rhs, &mut aggregate_expressions); + } + _ => {} } - _ => { - scalar_expressions.push(ProjectionColumn::Column(expr)); - } - }, + } } } - - let mixing_aggregate_and_non_aggregate_columns = - !aggregate_expressions.is_empty() && aggregate_expressions.len() != col_count; - - if mixing_aggregate_and_non_aggregate_columns { - crate::bail_parse_error!( - "mixing aggregate and non-aggregate columns is not allowed (GROUP BY is not supported)" - ); + if let Some(group_by) = group_by.as_ref() { + if aggregate_expressions.is_empty() { + crate::bail_parse_error!( + "GROUP BY clause without aggregate functions is not allowed" + ); + } + for scalar in projection_expressions.iter() { + match scalar { + ProjectionColumn::Column(_) => {} + _ => { + crate::bail_parse_error!( + "Only column references are allowed in the SELECT clause when using GROUP BY" + ); + } + } + } } if !aggregate_expressions.is_empty() { operator = Operator::Aggregate { source: Box::new(operator), aggregates: aggregate_expressions, + group_by: group_by.map(|g| g.exprs), // TODO: support HAVING id: operator_id_counter.get_next_id(), step: 0, } - } else if !scalar_expressions.is_empty() { + } + + if !projection_expressions.is_empty() { operator = Operator::Projection { source: Box::new(operator), - expressions: scalar_expressions, + expressions: projection_expressions, id: operator_id_counter.get_next_id(), step: 0, }; @@ -171,17 +218,18 @@ pub fn prepare_select_plan<'a>(schema: &Schema, select: ast::Select) -> Result

()?; if column_number == 0 { crate::bail_parse_error!("invalid column index: {}", column_number); } let maybe_result_column = columns.get(column_number - 1); match maybe_result_column { - Some(ResultColumn::Expr(expr, _)) => expr.clone(), + Some(ResultColumn::Expr(e, _)) => e.clone(), None => { crate::bail_parse_error!("invalid column index: {}", column_number) } @@ -190,6 +238,7 @@ pub fn prepare_select_plan<'a>(schema: &Schema, select: ast::Select) -> Result

>, ) -> Result { let select_plan = prepare_select_plan(schema, select)?; - let optimized_plan = optimize_plan(select_plan)?; - emit_program(database_header, optimized_plan) + let (optimized_plan, expr_result_cache) = optimize_plan(select_plan)?; + emit_program(database_header, optimized_plan, expr_result_cache) } diff --git a/core/types.rs b/core/types.rs index 84eeaca8e..415457b76 100644 --- a/core/types.rs +++ b/core/types.rs @@ -69,6 +69,21 @@ pub enum AggContext { GroupConcat(OwnedValue), } +const NULL: OwnedValue = OwnedValue::Null; + +impl AggContext { + pub fn final_value(&self) -> &OwnedValue { + match self { + AggContext::Avg(acc, _count) => acc, + AggContext::Sum(acc) => acc, + AggContext::Count(count) => count, + AggContext::Max(max) => max.as_ref().unwrap_or(&NULL), + AggContext::Min(min) => min.as_ref().unwrap_or(&NULL), + AggContext::GroupConcat(s) => s, + } + } +} + impl std::cmp::PartialOrd for OwnedValue { fn partial_cmp(&self, other: &Self) -> Option { match (self, other) { @@ -93,6 +108,21 @@ impl std::cmp::PartialOrd for OwnedValue { (OwnedValue::Null, OwnedValue::Null) => Some(std::cmp::Ordering::Equal), (OwnedValue::Null, _) => Some(std::cmp::Ordering::Less), (_, OwnedValue::Null) => Some(std::cmp::Ordering::Greater), + (OwnedValue::Agg(a), OwnedValue::Agg(b)) => a.partial_cmp(b), + _ => None, + } + } +} + +impl std::cmp::PartialOrd for AggContext { + fn partial_cmp(&self, other: &AggContext) -> Option { + match (self, other) { + (AggContext::Avg(a, _), AggContext::Avg(b, _)) => a.partial_cmp(b), + (AggContext::Sum(a), AggContext::Sum(b)) => a.partial_cmp(b), + (AggContext::Count(a), AggContext::Count(b)) => a.partial_cmp(b), + (AggContext::Max(a), AggContext::Max(b)) => a.partial_cmp(b), + (AggContext::Min(a), AggContext::Min(b)) => a.partial_cmp(b), + (AggContext::GroupConcat(a), AggContext::GroupConcat(b)) => a.partial_cmp(b), _ => None, } } diff --git a/core/vdbe/builder.rs b/core/vdbe/builder.rs index 721720a3e..681d06119 100644 --- a/core/vdbe/builder.rs +++ b/core/vdbe/builder.rs @@ -288,6 +288,16 @@ impl ProgramBuilder { assert!(*target_pc < 0); *target_pc = to_offset; } + Insn::Gosub { target_pc, .. } => { + assert!(*target_pc < 0); + *target_pc = to_offset; + } + Insn::Jump { target_pc_eq, .. } => { + // FIXME: this current implementation doesnt scale for insns that + // have potentially multiple label dependencies. + assert!(*target_pc_eq < 0); + *target_pc_eq = to_offset; + } _ => { todo!("missing resolve_label for {:?}", insn); } @@ -315,6 +325,10 @@ impl ProgramBuilder { .unwrap() } + pub fn resolve_cursor_to_table(&self, cursor_id: CursorID) -> Option { + self.cursor_ref[cursor_id].1.clone() + } + pub fn resolve_deferred_labels(&mut self) { for i in 0..self.deferred_label_resolutions.len() { let (label, insn_reference) = self.deferred_label_resolutions[i]; diff --git a/core/vdbe/explain.rs b/core/vdbe/explain.rs index cddc7e9e3..43c49903a 100644 --- a/core/vdbe/explain.rs +++ b/core/vdbe/explain.rs @@ -28,14 +28,25 @@ pub fn insn_to_str( 0, format!("r[{}]=r[{}]+r[{}]", dest, lhs, rhs), ), - Insn::Null { dest } => ( - "Null", + Insn::Multiply { lhs, rhs, dest } => ( + "Multiply", + *lhs as i32, + *rhs as i32, *dest as i32, - 0, - 0, OwnedValue::Text(Rc::new("".to_string())), 0, - format!("r[{}]=NULL", dest), + format!("r[{}]=r[{}]*r[{}]", dest, lhs, rhs), + ), + Insn::Null { dest, dest_end } => ( + "Null", + 0, + *dest as i32, + dest_end.map_or(0, |end| end as i32), + OwnedValue::Text(Rc::new("".to_string())), + 0, + dest_end.map_or(format!("r[{}]=NULL", dest), |end| { + format!("r[{}..{}]=NULL", dest, end) + }), ), Insn::NullRow { cursor_id } => ( "NullRow", @@ -55,6 +66,57 @@ pub fn insn_to_str( 0, format!("r[{}]!=NULL -> goto {}", reg, target_pc), ), + Insn::Compare { + start_reg_a, + start_reg_b, + count, + } => ( + "Compare", + *start_reg_a as i32, + *start_reg_b as i32, + *count as i32, + OwnedValue::Text(Rc::new("".to_string())), + 0, + format!( + "r[{}..{}]==r[{}..{}]", + start_reg_a, + start_reg_a + (count - 1), + start_reg_b, + start_reg_b + (count - 1) + ), + ), + Insn::Jump { + target_pc_lt, + target_pc_eq, + target_pc_gt, + } => ( + "Jump", + *target_pc_lt as i32, + *target_pc_eq as i32, + *target_pc_gt as i32, + OwnedValue::Text(Rc::new("".to_string())), + 0, + "".to_string(), + ), + Insn::Move { + source_reg, + dest_reg, + count, + } => ( + "Move", + *source_reg as i32, + *dest_reg as i32, + *count as i32, + OwnedValue::Text(Rc::new("".to_string())), + 0, + format!( + "r[{}..{}]=r[{}..{}]", + dest_reg, + dest_reg + (count - 1), + source_reg, + source_reg + (count - 1) + ), + ), Insn::IfPos { reg, target_pc, @@ -348,6 +410,27 @@ pub fn insn_to_str( 0, "".to_string(), ), + Insn::Gosub { + target_pc, + return_reg, + } => ( + "Gosub", + *return_reg as i32, + *target_pc as i32, + 0, + OwnedValue::Text(Rc::new("".to_string())), + 0, + "".to_string(), + ), + Insn::Return { return_reg } => ( + "Return", + *return_reg as i32, + 0, + 0, + OwnedValue::Text(Rc::new("".to_string())), + 0, + "".to_string(), + ), Insn::Integer { value, dest } => ( "Integer", *value as i32, @@ -478,7 +561,11 @@ pub fn insn_to_str( *cursor_id as i32, *columns as i32, 0, - OwnedValue::Text(Rc::new(format!("k({},{})", columns, to_print.join(",")))), + OwnedValue::Text(Rc::new(format!( + "k({},{})", + order.values.len(), + to_print.join(",") + ))), 0, format!("cursor={}", cursor_id), ) diff --git a/core/vdbe/mod.rs b/core/vdbe/mod.rs index eeecca3ec..e5f24693f 100644 --- a/core/vdbe/mod.rs +++ b/core/vdbe/mod.rs @@ -70,9 +70,10 @@ pub enum Insn { Init { target_pc: BranchOffset, }, - // Set NULL in the given register. + // Write a NULL into register dest. If dest_end is Some, then also write NULL into register dest_end and every register in between dest and dest_end. If dest_end is not set, then only register dest is set to NULL. Null { dest: usize, + dest_end: Option, }, // Move the cursor P1 to a null row. Any Column operations that occur while the cursor is on the null row will always write a NULL. NullRow { @@ -84,6 +85,30 @@ pub enum Insn { rhs: usize, dest: usize, }, + // Multiply two registers and store the result in a third register. + Multiply { + lhs: usize, + rhs: usize, + dest: usize, + }, + // Compare two vectors of registers in reg(P1)..reg(P1+P3-1) (call this vector "A") and in reg(P2)..reg(P2+P3-1) ("B"). Save the result of the comparison for use by the next Jump instruct. + Compare { + start_reg_a: usize, + start_reg_b: usize, + count: usize, + }, + // Jump to the instruction at address P1, P2, or P3 depending on whether in the most recent Compare instruction the P1 vector was less than, equal to, or greater than the P2 vector, respectively. + Jump { + target_pc_lt: BranchOffset, + target_pc_eq: BranchOffset, + target_pc_gt: BranchOffset, + }, + // Move the P3 values in register P1..P1+P3-1 over into registers P2..P2+P3-1. Registers P1..P1+P3-1 are left holding a NULL. It is an error for register ranges P1..P1+P3-1 and P2..P2+P3-1 to overlap. It is an error for P3 to be less than 1. + Move { + source_reg: usize, + dest_reg: usize, + count: usize, + }, // If the given register is a positive integer, decrement it by decrement_by and jump to the given PC. IfPos { reg: usize, @@ -214,6 +239,17 @@ pub enum Insn { target_pc: BranchOffset, }, + // Stores the current program counter into register 'return_reg' then jumps to address target_pc. + Gosub { + target_pc: BranchOffset, + return_reg: usize, + }, + + // Returns to the program counter stored in register 'return_reg'. + Return { + return_reg: usize, + }, + // Write an integer value into a register. Integer { value: i64, @@ -382,6 +418,7 @@ pub struct ProgramState { pub pc: BranchOffset, cursors: RefCell>>, registers: Vec, + last_compare: Option, ended_coroutine: bool, // flag to notify yield coroutine finished regex_cache: HashMap, } @@ -395,6 +432,7 @@ impl ProgramState { pc: 0, cursors, registers, + last_compare: None, ended_coroutine: false, regex_cache: HashMap::new(), } @@ -464,14 +502,123 @@ impl Program { (OwnedValue::Null, _) | (_, OwnedValue::Null) => { state.registers[dest] = OwnedValue::Null; } + (OwnedValue::Agg(aggctx), other) | (other, OwnedValue::Agg(aggctx)) => { + match other { + OwnedValue::Null => { + state.registers[dest] = OwnedValue::Null; + } + OwnedValue::Integer(i) => match aggctx.final_value() { + OwnedValue::Float(acc) => { + state.registers[dest] = OwnedValue::Float(acc + *i as f64); + } + OwnedValue::Integer(acc) => { + state.registers[dest] = OwnedValue::Integer(acc + i); + } + _ => { + todo!("{:?}", aggctx); + } + }, + OwnedValue::Float(f) => match aggctx.final_value() { + OwnedValue::Float(acc) => { + state.registers[dest] = OwnedValue::Float(acc + f); + } + OwnedValue::Integer(acc) => { + state.registers[dest] = OwnedValue::Float(*acc as f64 + f); + } + _ => { + todo!("{:?}", aggctx); + } + }, + OwnedValue::Agg(aggctx2) => { + let acc = aggctx.final_value(); + let acc2 = aggctx2.final_value(); + match (acc, acc2) { + (OwnedValue::Integer(acc), OwnedValue::Integer(acc2)) => { + state.registers[dest] = OwnedValue::Integer(acc + acc2); + } + (OwnedValue::Float(acc), OwnedValue::Float(acc2)) => { + state.registers[dest] = OwnedValue::Float(acc + acc2); + } + (OwnedValue::Integer(acc), OwnedValue::Float(acc2)) => { + state.registers[dest] = + OwnedValue::Float(*acc as f64 + acc2); + } + (OwnedValue::Float(acc), OwnedValue::Integer(acc2)) => { + state.registers[dest] = + OwnedValue::Float(acc + *acc2 as f64); + } + _ => { + todo!("{:?} {:?}", acc, acc2); + } + } + } + rest => unimplemented!("{:?}", rest), + } + } _ => { todo!(); } } state.pc += 1; } - Insn::Null { dest } => { - state.registers[*dest] = OwnedValue::Null; + Insn::Multiply { lhs, rhs, dest } => { + let lhs = *lhs; + let rhs = *rhs; + let dest = *dest; + match (&state.registers[lhs], &state.registers[rhs]) { + (OwnedValue::Integer(lhs), OwnedValue::Integer(rhs)) => { + state.registers[dest] = OwnedValue::Integer(lhs * rhs); + } + (OwnedValue::Float(lhs), OwnedValue::Float(rhs)) => { + state.registers[dest] = OwnedValue::Float(lhs * rhs); + } + (OwnedValue::Null, _) | (_, OwnedValue::Null) => { + state.registers[dest] = OwnedValue::Null; + } + (OwnedValue::Agg(aggctx), other) | (other, OwnedValue::Agg(aggctx)) => { + match other { + OwnedValue::Null => { + state.registers[dest] = OwnedValue::Null; + } + OwnedValue::Integer(i) => match aggctx.final_value() { + OwnedValue::Float(acc) => { + state.registers[dest] = OwnedValue::Float(acc * *i as f64); + } + OwnedValue::Integer(acc) => { + state.registers[dest] = OwnedValue::Integer(acc * i); + } + _ => { + todo!("{:?}", aggctx); + } + }, + OwnedValue::Float(f) => match aggctx.final_value() { + OwnedValue::Float(acc) => { + state.registers[dest] = OwnedValue::Float(acc * f); + } + OwnedValue::Integer(acc) => { + state.registers[dest] = OwnedValue::Float(*acc as f64 * f); + } + _ => { + todo!("{:?}", aggctx); + } + }, + rest => unimplemented!("{:?}", rest), + } + } + others => { + todo!("{:?}", others); + } + } + state.pc += 1; + } + Insn::Null { dest, dest_end } => { + if let Some(dest_end) = dest_end { + for i in *dest..=*dest_end { + state.registers[i] = OwnedValue::Null; + } + } else { + state.registers[*dest] = OwnedValue::Null; + } state.pc += 1; } Insn::NullRow { cursor_id } => { @@ -479,6 +626,68 @@ impl Program { cursor.set_null_flag(true); state.pc += 1; } + Insn::Compare { + start_reg_a, + start_reg_b, + count, + } => { + let start_reg_a = *start_reg_a; + let start_reg_b = *start_reg_b; + let count = *count; + + if start_reg_a + count > start_reg_b { + return Err(LimboError::InternalError( + "Compare registers overlap".to_string(), + )); + } + + let mut cmp = None; + for i in 0..count { + let a = &state.registers[start_reg_a + i]; + let b = &state.registers[start_reg_b + i]; + cmp = Some(a.cmp(b)); + if cmp != Some(std::cmp::Ordering::Equal) { + break; + } + } + state.last_compare = cmp; + state.pc += 1; + } + Insn::Jump { + target_pc_lt, + target_pc_eq, + target_pc_gt, + } => { + let cmp = state.last_compare.take(); + if cmp.is_none() { + return Err(LimboError::InternalError( + "Jump without compare".to_string(), + )); + } + let target_pc = match cmp.unwrap() { + std::cmp::Ordering::Less => *target_pc_lt, + std::cmp::Ordering::Equal => *target_pc_eq, + std::cmp::Ordering::Greater => *target_pc_gt, + }; + assert!(target_pc >= 0); + state.pc = target_pc; + } + Insn::Move { + source_reg, + dest_reg, + count, + } => { + let source_reg = *source_reg; + let dest_reg = *dest_reg; + let count = *count; + for i in 0..count { + state.registers[dest_reg + i] = std::mem::replace( + &mut state.registers[source_reg + i], + OwnedValue::Null, + ); + } + state.pc += 1; + } Insn::IfPos { reg, target_pc, @@ -788,6 +997,28 @@ impl Program { assert!(*target_pc >= 0); state.pc = *target_pc; } + Insn::Gosub { + target_pc, + return_reg, + } => { + assert!(*target_pc >= 0); + state.registers[*return_reg] = OwnedValue::Integer(state.pc as i64 + 1); + state.pc = *target_pc; + } + Insn::Return { return_reg } => { + if let OwnedValue::Integer(pc) = state.registers[*return_reg] { + if pc < 0 { + return Err(LimboError::InternalError( + "Return register is negative".to_string(), + )); + } + state.pc = pc; + } else { + return Err(LimboError::InternalError( + "Return register is not an integer".to_string(), + )); + } + } Insn::Integer { value, dest } => { state.registers[*dest] = OwnedValue::Integer(*value); state.pc += 1; @@ -1572,6 +1803,7 @@ fn exec_length(reg: &OwnedValue) -> OwnedValue { OwnedValue::Integer(reg.to_string().len() as i64) } OwnedValue::Blob(blob) => OwnedValue::Integer(blob.len() as i64), + OwnedValue::Agg(aggctx) => exec_length(&aggctx.final_value()), _ => reg.to_owned(), } } diff --git a/core/vdbe/sorter.rs b/core/vdbe/sorter.rs index 011392d1a..da214e531 100644 --- a/core/vdbe/sorter.rs +++ b/core/vdbe/sorter.rs @@ -97,7 +97,7 @@ impl Cursor for Sorter { let _ = moved_before; let key_fields = self.order.len(); let key = OwnedRecord::new(record.values[0..key_fields].to_vec()); - self.insert(key, OwnedRecord::new(record.values[key_fields..].to_vec())); + self.insert(key, OwnedRecord::new(record.values.to_vec())); Ok(CursorResult::Ok(())) } diff --git a/testing/all.test b/testing/all.test index 3c5457612..e4455c9ea 100755 --- a/testing/all.test +++ b/testing/all.test @@ -11,8 +11,9 @@ source $testdir/join.test source $testdir/json.test source $testdir/like.test source $testdir/orderby.test +source $testdir/groupby.test source $testdir/pragma.test source $testdir/scalar-functions.test source $testdir/scalar-functions-datetime.test source $testdir/select.test -source $testdir/where.test \ No newline at end of file +source $testdir/where.test diff --git a/testing/groupby.test b/testing/groupby.test new file mode 100644 index 000000000..a27d2a099 --- /dev/null +++ b/testing/groupby.test @@ -0,0 +1,107 @@ +#!/usr/bin/env tclsh + +set testdir [file dirname $argv0] +source $testdir/tester.tcl + +do_execsql_test group_by { + select u.first_name, sum(u.age) from users u group by u.first_name limit 10; +} {Aaron|2271 +Abigail|890 +Adam|1642 +Adrian|439 +Adriana|83 +Adrienne|318 +Aimee|33 +Alan|551 +Albert|369 +Alec|247} + +do_execsql_test group_by_two_joined_columns { + select u.first_name, p.name, sum(u.age) from users u join products p on u.id = p.id group by u.first_name, p.name limit 10; +} {Aimee|jeans|24 +Cindy|cap|37 +Daniel|coat|13 +Edward|sweatshirt|15 +Jamie|hat|94 +Jennifer|sweater|33 +Matthew|boots|77 +Nicholas|shorts|89 +Rachel|sneakers|63 +Tommy|shirt|18} + +do_execsql_test group_by_order_by { + select u.first_name, p.name, sum(u.age) from users u join products p on u.id = p.id group by u.first_name, p.name order by p.name limit 10; +} {Travis|accessories|22 +Matthew|boots|77 +Cindy|cap|37 +Daniel|coat|13 +Jamie|hat|94 +Aimee|jeans|24 +Tommy|shirt|18 +Nicholas|shorts|89 +Rachel|sneakers|63 +Jennifer|sweater|33} + +do_execsql_test group_by_order_by_aggregate { + select u.first_name, p.name, sum(u.age) from users u join products p on u.id = p.id group by u.first_name, p.name order by sum(u.age) limit 10; +} {Daniel|coat|13 +Edward|sweatshirt|15 +Tommy|shirt|18 +Travis|accessories|22 +Aimee|jeans|24 +Jennifer|sweater|33 +Cindy|cap|37 +Rachel|sneakers|63 +Matthew|boots|77 +Nicholas|shorts|89} + +do_execsql_test group_by_multiple_aggregates { + select u.first_name, sum(u.age), count(u.age) from users u group by u.first_name order by sum(u.age) limit 10; +} {Jaclyn|1|1 +Mia|1|1 +Kirsten|7|1 +Kellie|8|1 +Makayla|8|1 +Yvette|9|1 +Mckenzie|12|1 +Grant|14|1 +Mackenzie|15|1 +Cesar|17|1} + +do_execsql_test group_by_multiple_aggregates_2 { + select u.first_name, sum(u.age), group_concat(u.age) from users u group by u.first_name order by u.first_name limit 10; +} {Aaron|2271|52,46,17,69,71,91,34,30,97,81,47,98,45,69,97,18,38,26,98,60,33,97,42,43,43,22,18,75,56,67,83,58,82,28,22,72,5,58,96,32,55 +Abigail|890|17,82,62,57,55,5,9,83,93,22,23,57,56,100,74,95 +Adam|1642|34,23,10,11,46,40,2,57,51,80,65,24,15,84,59,6,34,100,32,79,57,5,77,34,30,19,54,74,89,98,72,91,90 +Adrian|439|37,28,94,76,69,60,34,41 +Adriana|83|83 +Adrienne|318|79,74,82,33,50 +Aimee|33|24,9 +Alan|551|18,52,30,62,96,13,85,97,98 +Albert|369|99,80,41,7,64,7,26,41,4 +Alec|247|55,48,53,91} + +do_execsql_test group_by_complex_order_by { + select u.first_name, group_concat(u.last_name) from users u group by u.first_name order by -1 * length(group_concat(u.last_name)) limit 1; +} {Michael|Love,Finley,Hurst,Molina,Williams,Brown,King,Whitehead,Ochoa,Davis,Rhodes,Mcknight,Reyes,Johnston,Smith,Young,Lopez,Roberts,Green,Cole,Lane,Wagner,Allen,Simpson,Schultz,Perry,Mendez,Gibson,Hale,Williams,Bradford,Johnson,Weber,Nunez,Walls,Gonzalez,Park,Blake,Vazquez,Garcia,Mathews,Pacheco,Johnson,Perez,Gibson,Sparks,Chapman,Tate,Dudley,Miller,Alvarado,Ward,Nguyen,Rosales,Flynn,Ball,Jones,Hoffman,Clarke,Rivera,Moore,Hardin,Dillon,Montgomery,Rodgers,Payne,Williams,Mueller,Hernandez,Ware,Yates,Grimes,Gilmore,Johnson,Clark,Rodriguez,Walters,Powell,Colon,Mccoy,Allen,Quinn,Dunn,Wilson,Thompson,Bradford,Hunter,Gilmore,Woods,Bennett,Collier,Ali,Herrera,Lawson,Garner,Perez,Brown,Pena,Allen,Davis,Washington,Jackson,Khan,Martinez,Blackwell,Lee,Parker,Lynn,Johnson,Benton,Leonard,Munoz,Alvarado,Mathews,Salazar,Nelson,Jones,Carpenter,Walter,Young,Coleman,Berry,Clark,Powers,Meyer,Lewis,Barton,Guzman,Schneider,Hernandez,Mclaughlin,Allen,Atkinson,Woods,Rivera,Jones,Gordon,Dennis,Yoder,Hunt,Vance,Nelson,Park,Barnes,Lang,Williams,Cervantes,Tran,Anderson,Todd,Gonzalez,Lowery,Sanders,Mccullough,Haley,Rogers,Perez,Watson,Weaver,Wise,Walter,Summers,Long,Chan,Williams,Mccoy,Duncan,Roy,West,Christensen,Cuevas,Garcia,Williams,Butler,Anderson,Armstrong,Villarreal,Boyer,Johnson,Dyer,Hurst,Wilkins,Mercer,Taylor,Montes,Mccarty,Gill,Rodriguez,Williams,Copeland,Hansen,Palmer,Alexander,White,Taylor,Bowers,Hughes,Gibbs,Myers,Kennedy,Sanchez,Bell,Wilson,Berry,Spears,Patton,Rose,Smith,Bowen,Nicholson,Stewart,Quinn,Powell,Delgado,Mills,Duncan,Phillips,Grant,Hatfield,Russell,Anderson,Reed,Mahoney,Mcguire,Ortega,Logan,Schmitt,Walker} + +do_execsql_test group_by_complex_order_by_2 { + select u.first_name, sum(u.age) from users u group by u.first_name order by -1 * sum(u.age) limit 10; +} {Michael|11204 +David|8758 +Robert|8109 +Jennifer|7700 +John|7299 +Christopher|6397 +James|5921 +Joseph|5711 +Brian|5059 +William|5047} + +do_execsql_test group_by_and_binary_expression_that_depends_on_two_aggregates { + select u.first_name, sum(u.age) + count(1) from users u group by u.first_name limit 5; +} {Aaron|2312 +Abigail|906 +Adam|1675 +Adrian|447 +Adriana|84}