diff --git a/core/translate/emitter.rs b/core/translate/emitter.rs index 77274df84..5f345036d 100644 --- a/core/translate/emitter.rs +++ b/core/translate/emitter.rs @@ -4,7 +4,7 @@ use std::rc::Rc; use std::sync::Arc; -use limbo_sqlite3_parser::ast::{self}; +use limbo_sqlite3_parser::ast::{self, SortOrder}; use super::aggregation::emit_ungrouped_aggregation; use super::expr::{translate_condition_expr, translate_expr, ConditionMetadata}; @@ -15,13 +15,15 @@ use super::main_loop::{ close_loop, emit_loop, init_distinct, init_loop, open_loop, LeftJoinMetadata, LoopLabels, }; use super::order_by::{emit_order_by, init_order_by, SortMetadata}; -use super::plan::{JoinOrderMember, Operation, SelectPlan, TableReference, UpdatePlan}; +use super::plan::{ + JoinOrderMember, Operation, SelectPlan, SelectQueryType, TableReference, UpdatePlan, +}; use super::schema::ParseSchema; use super::select::emit_simple_count; use super::subquery::emit_subqueries; use crate::error::SQLITE_CONSTRAINT_PRIMARYKEY; use crate::function::Func; -use crate::schema::Index; +use crate::schema::{Index, IndexColumn}; use crate::translate::plan::{DeletePlan, Plan, Search}; use crate::translate::values::emit_values; use crate::util::exprs_are_equivalent; @@ -208,67 +210,134 @@ fn emit_program_for_compound_select( // Each subselect gets their own TranslateCtx, but they share the same limit_ctx // because the LIMIT applies to the entire compound select, not just a single subselect. - let mut t_ctx_list = Vec::with_capacity(rest.len() + 1); - let reg_limit = if let Some(limit) = limit { + // The way LIMIT works with compound selects is: + // - If a given subselect appears BEFORE any UNION, then do NOT count those rows towards the LIMIT, + // because the rows from those subselects need to be deduplicated before they start being counted. + // - If a given subselect appears AFTER the last UNION, then count those rows towards the LIMIT immediately. 
+ let limit_ctx = limit.map(|limit| { let reg = program.alloc_register(); program.emit_insn(Insn::Integer { value: limit as i64, dest: reg, }); - Some(reg) - } else { - None - }; - let limit_ctx = if let Some(reg_limit) = reg_limit { - Some(LimitCtx::new_shared(reg_limit)) - } else { - None - }; - let mut t_ctx_first = TranslateCtx::new( + LimitCtx::new_shared(reg) + }); + + // Each subselect gets their own TranslateCtx. + let mut t_ctx_list = Vec::with_capacity(rest.len() + 1); + t_ctx_list.push(TranslateCtx::new( program, syms, first.table_references.len(), first.result_columns.len(), - ); - t_ctx_first.limit_ctx = limit_ctx; - t_ctx_list.push(t_ctx_first); - - for (select, _) in rest.iter() { - let mut t_ctx = TranslateCtx::new( + )); + rest.iter().for_each(|(select, _)| { + let t_ctx = TranslateCtx::new( program, syms, select.table_references.len(), select.result_columns.len(), ); - t_ctx.limit_ctx = limit_ctx; t_ctx_list.push(t_ctx); + }); + + // Compound select operators have the same precedence and are left-associative. + // If there is any remaining UNION operator on the right side of a given sub-SELECT, + // all of the rows from the preceding UNION arms need to be deduplicated. + // This is done by creating an ephemeral index and inserting all the rows from the left side of + // the last UNION arm into it. + // Then, as soon as there are no more UNION operators left, all the deduplicated rows from the + // ephemeral index are emitted, and lastly the rows from the remaining sub-SELECTS are emitted + // as is, as they don't require deduplication. + let mut first_t_ctx = t_ctx_list.remove(0); + let requires_union_deduplication = rest + .iter() + .any(|(_, operator)| operator == &ast::CompoundOperator::Union); + if requires_union_deduplication { + // appears BEFORE a UNION operator, so do not count those rows towards the LIMIT. + first.limit = None; + } else { + // appears AFTER the last UNION operator, so count those rows towards the LIMIT. 
+ first_t_ctx.limit_ctx = limit_ctx; } - let mut first_t_ctx = t_ctx_list.remove(0); + let mut union_dedupe_index = if requires_union_deduplication { + let dedupe_index = get_union_dedupe_index(program, &first); + first.query_type = SelectQueryType::UnionArm { + index_cursor_id: dedupe_index.0, + dedupe_index: dedupe_index.1.clone(), + }; + Some(dedupe_index) + } else { + None + }; + + // Emit the first SELECT emit_query(program, &mut first, &mut first_t_ctx)?; - // TODO: add support for UNION, EXCEPT, INTERSECT + // Emit the remaining SELECTs. Any selects on the left side of a UNION must deduplicate their + // results with the ephemeral index created above. while !t_ctx_list.is_empty() { let label_next_select = program.allocate_label(); // If the LIMIT is reached in any subselect, jump to either: // a) the IfNot of the next subselect, or // b) the end of the program - if let Some(reg_limit) = reg_limit { + if let Some(limit_ctx) = limit_ctx { program.emit_insn(Insn::IfNot { - reg: reg_limit, + reg: limit_ctx.reg_limit, target_pc: label_next_select, jump_if_null: true, }); } let mut t_ctx = t_ctx_list.remove(0); + let requires_union_deduplication = rest + .iter() + .any(|(_, operator)| operator == &ast::CompoundOperator::Union); let (mut select, operator) = rest.remove(0); - if operator != ast::CompoundOperator::UnionAll { + if operator != ast::CompoundOperator::UnionAll && operator != ast::CompoundOperator::Union { crate::bail_parse_error!("unimplemented compound select operator: {:?}", operator); } + + if requires_union_deduplication { + // Again: appears BEFORE a UNION operator, so do not count those rows towards the LIMIT. + select.limit = None; + } else { + // appears AFTER the last UNION operator, so count those rows towards the LIMIT. 
+ t_ctx.limit_ctx = limit_ctx; + } + + if requires_union_deduplication { + select.query_type = SelectQueryType::UnionArm { + index_cursor_id: union_dedupe_index.as_ref().unwrap().0, + dedupe_index: union_dedupe_index.as_ref().unwrap().1.clone(), + }; + } else if let Some((dedupe_cursor_id, dedupe_index)) = union_dedupe_index.take() { + // When there are no more UNION operators left, all the deduplicated rows from the preceding union arms need to be emitted + // as result rows. + read_deduplicated_union_rows( + program, + dedupe_cursor_id, + dedupe_index.as_ref(), + limit_ctx, + label_next_select, + ); + } emit_query(program, &mut select, &mut t_ctx)?; program.preassign_label_to_next_insn(label_next_select); } + if let Some((dedupe_cursor_id, dedupe_index)) = union_dedupe_index { + let label_jump_over_dedupe = program.allocate_label(); + read_deduplicated_union_rows( + program, + dedupe_cursor_id, + dedupe_index.as_ref(), + limit_ctx, + label_jump_over_dedupe, + ); + program.preassign_label_to_next_insn(label_jump_over_dedupe); + } + program.epilogue(TransactionMode::Read); program.result_columns = first.result_columns; program.table_references = first.table_references; @@ -276,6 +345,84 @@ fn emit_program_for_compound_select( Ok(()) } +/// Creates an ephemeral index that will be used to deduplicate the results of any sub-selects +/// that appear before the last UNION operator. 
+fn get_union_dedupe_index( + program: &mut ProgramBuilder, + first_select_in_compound: &SelectPlan, +) -> (usize, Arc<Index>) { + let dedupe_index = Arc::new(Index { + columns: first_select_in_compound + .result_columns + .iter() + .map(|c| IndexColumn { + name: c + .name(&first_select_in_compound.table_references) + .map(|n| n.to_string()) + .unwrap_or_default(), + order: SortOrder::Asc, + pos_in_table: 0, + collation: None, // FIXME: this should be inferred + }) + .collect(), + name: "union_dedupe".to_string(), + root_page: 0, + ephemeral: true, + table_name: String::new(), + unique: true, + has_rowid: false, + }); + let cursor_id = program.alloc_cursor_id( + Some(dedupe_index.name.clone()), + CursorType::BTreeIndex(dedupe_index.clone()), + ); + program.emit_insn(Insn::OpenEphemeral { + cursor_id, + is_table: false, + }); + (cursor_id, dedupe_index.clone()) +} + +/// Emits the bytecode for reading deduplicated rows from the ephemeral index created for UNION operators. +fn read_deduplicated_union_rows( + program: &mut ProgramBuilder, + dedupe_cursor_id: usize, + dedupe_index: &Index, + limit_ctx: Option<LimitCtx>, + label_limit_reached: BranchOffset, +) { + let label_dedupe_next = program.allocate_label(); + let label_dedupe_loop_start = program.allocate_label(); + let dedupe_cols_start_reg = program.alloc_registers(dedupe_index.columns.len()); + program.emit_insn(Insn::Rewind { + cursor_id: dedupe_cursor_id, + pc_if_empty: label_dedupe_next, + }); + program.preassign_label_to_next_insn(label_dedupe_loop_start); + for col_idx in 0..dedupe_index.columns.len() { + program.emit_insn(Insn::Column { + cursor_id: dedupe_cursor_id, + column: col_idx, + dest: dedupe_cols_start_reg + col_idx, + }); + } + program.emit_insn(Insn::ResultRow { + start_reg: dedupe_cols_start_reg, + count: dedupe_index.columns.len(), + }); + if let Some(limit_ctx) = limit_ctx { + program.emit_insn(Insn::DecrJumpZero { + reg: limit_ctx.reg_limit, + target_pc: label_limit_reached, + }) + } + 
program.preassign_label_to_next_insn(label_dedupe_next); + program.emit_insn(Insn::Next { + cursor_id: dedupe_cursor_id, + pc_if_next: label_dedupe_loop_start, + }); +} + fn emit_program_for_select( program: &mut ProgramBuilder, mut plan: SelectPlan, diff --git a/core/translate/plan.rs b/core/translate/plan.rs index 4ee6b8047..3de6b176e 100644 --- a/core/translate/plan.rs +++ b/core/translate/plan.rs @@ -319,6 +319,13 @@ pub enum SelectQueryType { /// The index of the first instruction in the bytecode that implements the subquery. coroutine_implementation_start: BranchOffset, }, + /// One arm of a UNION query, so its results need to be fed into a temp index for deduplication. + UnionArm { + /// The cursor ID of the temp index that will be used to deduplicate the results. + index_cursor_id: CursorID, + /// The deduplication index. + dedupe_index: Arc<Index>, + }, } #[derive(Debug, Clone, Copy, PartialEq, Eq)] diff --git a/core/translate/result_row.rs b/core/translate/result_row.rs index d1466773c..03e2fcb10 100644 --- a/core/translate/result_row.rs +++ b/core/translate/result_row.rs @@ -1,5 +1,9 @@ use crate::{ - vdbe::{builder::ProgramBuilder, insn::Insn, BranchOffset}, + vdbe::{ + builder::ProgramBuilder, + insn::{IdxInsertFlags, Insn}, + BranchOffset, + }, Result, }; @@ -88,6 +92,25 @@ pub fn emit_result_row_and_limit( count: plan.result_columns.len(), }); } + SelectQueryType::UnionArm { + index_cursor_id, + dedupe_index, + } => { + let record_reg = program.alloc_register(); + program.emit_insn(Insn::MakeRecord { + start_reg: result_columns_start_reg, + count: plan.result_columns.len(), + dest_reg: record_reg, + index_name: Some(dedupe_index.name.clone()), + }); + program.emit_insn(Insn::IdxInsert { + cursor_id: *index_cursor_id, + record_reg, + unpacked_start: None, + unpacked_count: None, + flags: IdxInsertFlags::new(), + }); + } SelectQueryType::Subquery { yield_reg, .. 
} => { program.emit_insn(Insn::Yield { yield_reg: *yield_reg, diff --git a/core/translate/select.rs b/core/translate/select.rs index bf2112829..b96e142b6 100644 --- a/core/translate/select.rs +++ b/core/translate/select.rs @@ -100,9 +100,13 @@ pub fn prepare_select_plan<'a>( )?; let mut rest = Vec::with_capacity(compounds.len()); for CompoundSelect { select, operator } in compounds { - // TODO: add support for UNION, EXCEPT and INTERSECT - if operator != ast::CompoundOperator::UnionAll { - crate::bail_parse_error!("only UNION ALL is supported for compound SELECTs"); + // TODO: add support for EXCEPT and INTERSECT + if operator != ast::CompoundOperator::UnionAll + && operator != ast::CompoundOperator::Union + { + crate::bail_parse_error!( + "only UNION ALL and UNION are supported for compound SELECTs" + ); } let plan = prepare_one_select_plan( schema, diff --git a/core/translate/values.rs b/core/translate/values.rs index 390c6bb76..3f6faf718 100644 --- a/core/translate/values.rs +++ b/core/translate/values.rs @@ -21,6 +21,7 @@ pub fn emit_values( SelectQueryType::Subquery { yield_reg, .. } => { emit_values_in_subquery(program, plan, resolver, yield_reg)? } + SelectQueryType::UnionArm { .. } => unreachable!(), }; Ok(reg_result_cols_start) } @@ -56,6 +57,7 @@ fn emit_values_when_single_row( end_offset: BranchOffset::Offset(0), }); } + SelectQueryType::UnionArm { .. 
} => unreachable!(), } Ok(start_reg) } diff --git a/testing/select.test b/testing/select.test index 4b377e1bc..3162e0904 100755 --- a/testing/select.test +++ b/testing/select.test @@ -36,7 +36,7 @@ do_execsql_test select-blob-empty { } {} do_execsql_test select-blob-ascii { - SELECT x'6C696D626f'; + SELECT x'6C696D626F'; } {limbo} do_execsql_test select-blob-emoji { @@ -285,3 +285,76 @@ do_execsql_test_on_specific_db {:memory:} select-union-all-with-filters { 6 10} +do_execsql_test_on_specific_db {:memory:} select-union-1 { + CREATE TABLE t(x TEXT, y TEXT); + CREATE TABLE u(x TEXT, y TEXT); + INSERT INTO t VALUES('x','x'),('y','y'); + INSERT INTO u VALUES('x','x'),('y','y'); + + select * from t UNION select * from u; +} {x|x +y|y} + +do_execsql_test_on_specific_db {:memory:} select-union-all-union { + CREATE TABLE t(x TEXT, y TEXT); + CREATE TABLE u(x TEXT, y TEXT); + CREATE TABLE v(x TEXT, y TEXT); + INSERT INTO t VALUES('x','x'),('y','y'); + INSERT INTO u VALUES('x','x'),('y','y'); + INSERT INTO v VALUES('x','x'),('y','y'); + + select * from t UNION select * from u UNION ALL select * from v; +} {x|x +y|y +x|x +y|y} + +do_execsql_test_on_specific_db {:memory:} select-union-all-union-2 { + CREATE TABLE t(x TEXT, y TEXT); + CREATE TABLE u(x TEXT, y TEXT); + CREATE TABLE v(x TEXT, y TEXT); + INSERT INTO t VALUES('x','x'),('y','y'); + INSERT INTO u VALUES('x','x'),('y','y'); + INSERT INTO v VALUES('x','x'),('y','y'); + + select * from t UNION ALL select * from u UNION select * from v; +} {x|x +y|y} + +do_execsql_test_on_specific_db {:memory:} select-union-3 { + CREATE TABLE t(x TEXT, y TEXT); + CREATE TABLE u(x TEXT, y TEXT); + CREATE TABLE v(x TEXT, y TEXT); + INSERT INTO t VALUES('x','x'),('y','y'); + INSERT INTO u VALUES('x','x'),('y','y'); + INSERT INTO v VALUES('x','x'),('y','y'); + + select * from t UNION select * from u UNION select * from v; +} {x|x +y|y} + +do_execsql_test_on_specific_db {:memory:} select-union-4 { + CREATE TABLE t(x TEXT, y TEXT); + CREATE 
TABLE u(x TEXT, y TEXT); + CREATE TABLE v(x TEXT, y TEXT); + INSERT INTO t VALUES('x','x'),('y','y'); + INSERT INTO u VALUES('x','x'),('y','y'); + INSERT INTO v VALUES('x','x'),('y','y'); + + select * from t UNION select * from u UNION select * from v UNION select * from t; +} {x|x +y|y} + +do_execsql_test_on_specific_db {:memory:} select-union-all-union-3 { + CREATE TABLE t(x TEXT, y TEXT); + CREATE TABLE u(x TEXT, y TEXT); + CREATE TABLE v(x TEXT, y TEXT); + INSERT INTO t VALUES('x','x'),('y','y'); + INSERT INTO u VALUES('x','x'),('y','y'); + INSERT INTO v VALUES('x','x'),('y','y'); + + select * from t UNION select * from u UNION select * from v UNION ALL select * from t; +} {x|x +y|y +x|x +y|y} diff --git a/tests/integration/fuzz/mod.rs b/tests/integration/fuzz/mod.rs index 232348e39..7fd5d23d3 100644 --- a/tests/integration/fuzz/mod.rs +++ b/tests/integration/fuzz/mod.rs @@ -468,7 +468,7 @@ mod tests { } for iter_num in 0..NUM_FUZZ_ITERATIONS { - // Number of SELECT clauses to be UNION ALL'd + // Number of SELECT clauses let num_selects_in_union = rng.random_range(1..=(table_names.len() + MAX_SELECTS_IN_UNION_EXTRA)); let mut select_statements = Vec::new(); @@ -479,7 +479,15 @@ mod tests { select_statements.push(format!("SELECT c1, c2, c3 FROM {}", table_to_select_from)); } - let mut query = select_statements.join(" UNION ALL "); + const COMPOUND_OPERATORS: [&str; 2] = [" UNION ALL ", " UNION "]; + + let mut query = String::new(); + for (i, select_statement) in select_statements.iter().enumerate() { + if i > 0 { + query.push_str(COMPOUND_OPERATORS.choose(&mut rng).unwrap()); + } + query.push_str(select_statement); + } if rng.random_bool(0.8) { let limit_val = rng.random_range(0..=MAX_LIMIT_VALUE); // LIMIT 0 is valid @@ -497,9 +505,15 @@ mod tests { let sqlite_results = sqlite_exec_rows(&sqlite_conn, &query); assert_eq!( - limbo_results, sqlite_results, - "query: {}, limbo: {:?}, sqlite: {:?}, seed: {}", - query, limbo_results, sqlite_results, seed + 
limbo_results, + sqlite_results, + "query: {}, limbo.len(): {}, sqlite.len(): {}, limbo: {:?}, sqlite: {:?}, seed: {}", + query, + limbo_results.len(), + sqlite_results.len(), + limbo_results, + sqlite_results, + seed ); } }