diff --git a/core/vdbe/execute.rs b/core/vdbe/execute.rs index e68246770..5830d9e83 100644 --- a/core/vdbe/execute.rs +++ b/core/vdbe/execute.rs @@ -1359,6 +1359,20 @@ fn read_varint_fast(buf: &[u8]) -> Result<(u64, usize)> { read_varint(buf) } +#[derive(Debug, Clone, Copy)] +pub enum OpColumnState { + Start, + Rowid { + index_cursor_id: usize, + table_cursor_id: usize, + }, + Seek { + rowid: i64, + table_cursor_id: usize, + }, + GetColumn, +} + pub fn op_column( program: &Program, state: &mut ProgramState, @@ -1375,295 +1389,334 @@ pub fn op_column( }, insn ); - if let Some((index_cursor_id, table_cursor_id)) = state.deferred_seeks[*cursor_id].take() { - let deferred_seek = 'd: { - let rowid = { - let mut index_cursor = state.get_cursor(index_cursor_id); - let index_cursor = index_cursor.as_btree_mut(); - match index_cursor.rowid()? { - IOResult::IO => { - break 'd Some((index_cursor_id, table_cursor_id)); - } - IOResult::Done(rowid) => rowid, - } - }; - let mut table_cursor = state.get_cursor(table_cursor_id); - let table_cursor = table_cursor.as_btree_mut(); - match table_cursor.seek( - SeekKey::TableRowId(rowid.unwrap()), - SeekOp::GE { eq_only: true }, - )? { - IOResult::Done(_) => None, - IOResult::IO => Some((index_cursor_id, table_cursor_id)), - } - }; - if let Some(deferred_seek) = deferred_seek { - state.deferred_seeks[*cursor_id] = Some(deferred_seek); - return Ok(InsnFunctionStepResult::IO); - } - } - - let (_, cursor_type) = program.cursor_ref.get(*cursor_id).unwrap(); - match cursor_type { - CursorType::BTreeTable(_) | CursorType::BTreeIndex(_) => { - 'ifnull: { - let mut cursor_ref = - must_be_btree_cursor!(*cursor_id, program.cursor_ref, state, "Column"); - let cursor = cursor_ref.as_btree_mut(); - - if cursor.get_null_flag() { - drop(cursor_ref); - state.registers[*dest] = Register::Value(Value::Null); - state.pc += 1; - return Ok(InsnFunctionStepResult::Step); - } - - let record_result = return_if_io!(cursor.record()); - let Some(payload) = record_result.as_ref().map(|r| r.get_payload()) else { - break 'ifnull; - }; - - let mut record_cursor = cursor.record_cursor.borrow_mut(); - - if record_cursor.offsets.is_empty() { - let (header_size, header_len_bytes) = read_varint_fast(payload)?; - let header_size = header_size as usize; - - debug_assert!(header_size <= payload.len() && header_size <= 98307, "header_size: {header_size}, header_len_bytes: {header_len_bytes}, payload.len(): {}", payload.len()); - - record_cursor.header_size = header_size; - record_cursor.header_offset = header_len_bytes; - record_cursor.offsets.push(header_size); - } - - let target_column = *column; - let mut parse_pos = record_cursor.header_offset; - let mut data_offset = record_cursor - .offsets - .last() - .copied() - .expect("header_offset must be set"); - - // Parse the header for serial types incrementally until we have the target column - while record_cursor.serial_types.len() <= target_column - && parse_pos < record_cursor.header_size + 'outer: loop { + match state.op_column_state { + OpColumnState::Start => { + if let Some((index_cursor_id, table_cursor_id)) = + state.deferred_seeks[*cursor_id].take() { - let (serial_type, varint_len) = read_varint_fast(&payload[parse_pos..])?; - record_cursor.serial_types.push(serial_type); - parse_pos += varint_len; - let data_size = match serial_type { - // NULL - 0 => 0, - // I8 - 1 => 1, - // I16 - 2 => 2, - // I24 - 3 => 3, - // I32 - 4 => 4, - // I48 - 5 => 6, - // I64 - 6 => 8, - // F64 - 7 => 8, - // CONST_INT0 - 8 => 0, - // CONST_INT1 - 9 => 0, - // BLOB - n if n >= 12 && n & 1 == 0 => (n - 12) >> 1, - // TEXT - n if n >= 13 && n & 1 == 1 => (n - 13) >> 1, - // Reserved - 10 | 11 => { - return Err(LimboError::Corrupt(format!( - "Reserved serial type: {serial_type}" - ))) - } - _ => unreachable!("Invalid serial type: {serial_type}"), - } as usize; - data_offset += data_size; - record_cursor.offsets.push(data_offset); - } - - debug_assert!( - parse_pos <= record_cursor.header_size, - "parse_pos: {parse_pos}, header_size: {}", - record_cursor.header_size - ); - record_cursor.header_offset = parse_pos; - if target_column >= record_cursor.serial_types.len() { - break 'ifnull; - } - - let start_offset = record_cursor.offsets[target_column]; - let end_offset = record_cursor.offsets[target_column + 1]; - // SAFETY: We know that the payload is valid until the next row is processed. - let buf = unsafe { - std::mem::transmute::<&[u8], &'static [u8]>(&payload[start_offset..end_offset]) - }; - let serial_type = record_cursor.serial_types[target_column]; - drop(record_result); - drop(record_cursor); - drop(cursor_ref); - - match serial_type { - // NULL - 0 => break 'ifnull, - // I8 - 1 => { - state.registers[*dest] = - Register::Value(Value::Integer(buf[0] as i8 as i64)); - } - // I16 - 2 => { - state.registers[*dest] = Register::Value(Value::Integer( - i16::from_be_bytes([buf[0], buf[1]]) as i64, - )); - } - // I24 - 3 => { - let sign_extension = (buf[0] > 0x7F) as u8 * 0xFF; - let value = Value::Integer(i32::from_be_bytes([ - sign_extension, - buf[0], - buf[1], - buf[2], - ]) as i64); - state.registers[*dest] = Register::Value(value); - } - // I32 - 4 => { - let value = - Value::Integer( - i32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]) as i64 - ); - state.registers[*dest] = Register::Value(value); - } - // I48 - 5 => { - let sign_extension = (buf[0] > 0x7F) as u8 * 0xFF; - let value = Value::Integer(i64::from_be_bytes([ - sign_extension, - sign_extension, - buf[0], - buf[1], - buf[2], - buf[3], - buf[4], - buf[5], - ])); - state.registers[*dest] = Register::Value(value); - } - // I64 - 6 => { - let value = - Value::Integer(i64::from_be_bytes(buf[..8].try_into().unwrap())); - state.registers[*dest] = Register::Value(value); - } - // F64 - 7 => { - let value = Value::Float(f64::from_be_bytes(buf[..8].try_into().unwrap())); - state.registers[*dest] = Register::Value(value); - } - // CONST_INT0 - 8 => { - state.registers[*dest] = Register::Value(Value::Integer(0)); - } - // CONST_INT1 - 9 => { - state.registers[*dest] = Register::Value(Value::Integer(1)); - } - // BLOB - n if n >= 12 && n & 1 == 0 => { - // Try to reuse the registers when allocation is not needed. - match state.registers[*dest] { - Register::Value(Value::Blob(ref mut existing_blob)) => { - existing_blob.do_extend(&buf); - } - _ => { - state.registers[*dest] = Register::Value(Value::Blob(buf.to_vec())); - } - } - } - // TEXT - n if n >= 13 && n & 1 == 1 => { - // Try to reuse the registers when allocation is not needed. - match state.registers[*dest] { - Register::Value(Value::Text(ref mut existing_text)) => { - // SAFETY: We know the text is valid UTF-8 because we only accept valid UTF-8 and the serial type is TEXT. - let text = unsafe { std::str::from_utf8_unchecked(buf) }; - existing_text.do_extend(&text); - } - _ => { - // SAFETY: We know the text is valid UTF-8 because we only accept valid UTF-8 and the serial type is TEXT. - let text = unsafe { std::str::from_utf8_unchecked(buf) }; - state.registers[*dest] = - Register::Value(Value::Text(Text::new(text))); - } - } - } - _ => panic!("Invalid serial type: {serial_type}"), - } - - state.pc += 1; - return Ok(InsnFunctionStepResult::Step); - }; - - // DEFAULT handling. Try to reuse the registers when allocation is not needed. - let Some(ref default) = default else { - state.registers[*dest] = Register::Value(Value::Null); - state.pc += 1; - return Ok(InsnFunctionStepResult::Step); - }; - match (default, &mut state.registers[*dest]) { - (Value::Text(new_text), Register::Value(Value::Text(existing_text))) => { - existing_text.do_extend(new_text); - } - (Value::Blob(new_blob), Register::Value(Value::Blob(existing_blob))) => { - existing_blob.do_extend(new_blob); - } - _ => { - state.registers[*dest] = Register::Value(default.clone()); - } - } - state.pc += 1; - return Ok(InsnFunctionStepResult::Step); - } - CursorType::Sorter => { - let record = { - let mut cursor = state.get_cursor(*cursor_id); - let cursor = cursor.as_sorter_mut(); - cursor.record().cloned() - }; - if let Some(record) = record { - state.registers[*dest] = Register::Value(match record.get_value_opt(*column) { - Some(val) => val.to_owned(), - None => default.clone().unwrap_or(Value::Null), - }); - } else { - state.registers[*dest] = Register::Value(Value::Null); - } - } - CursorType::Pseudo(_) => { - let value = { - let mut cursor = state.get_cursor(*cursor_id); - let cursor = cursor.as_pseudo_mut(); - if let Some(record) = cursor.record() { - record.get_value(*column)?.to_owned() + state.op_column_state = OpColumnState::Rowid { + index_cursor_id, + table_cursor_id, + }; } else { - Value::Null + state.op_column_state = OpColumnState::GetColumn; } - }; - state.registers[*dest] = Register::Value(value); - } - CursorType::VirtualTable(_) => { - panic!("Insn:Column on virtual table cursor, use Insn:VColumn instead"); + } + OpColumnState::Rowid { + index_cursor_id, + table_cursor_id, + } => { + let rowid = { + let mut index_cursor = state.get_cursor(index_cursor_id); + let index_cursor = index_cursor.as_btree_mut(); + return_if_io!(index_cursor.rowid()) + }; + state.op_column_state = OpColumnState::Seek { + rowid: rowid.unwrap(), + table_cursor_id, + }; + } + OpColumnState::Seek { + rowid, + table_cursor_id, + } => { + { + let mut table_cursor = state.get_cursor(table_cursor_id); + let table_cursor = table_cursor.as_btree_mut(); + return_if_io!( + table_cursor.seek(SeekKey::TableRowId(rowid), SeekOp::GE { eq_only: true }) + ); + } + state.op_column_state = OpColumnState::GetColumn; + } + OpColumnState::GetColumn => { + let (_, cursor_type) = program.cursor_ref.get(*cursor_id).unwrap(); + match cursor_type { + CursorType::BTreeTable(_) | CursorType::BTreeIndex(_) => { + 'ifnull: { + let mut cursor_ref = must_be_btree_cursor!( + *cursor_id, + program.cursor_ref, + state, + "Column" + ); + let cursor = cursor_ref.as_btree_mut(); + + if cursor.get_null_flag() { + drop(cursor_ref); + state.registers[*dest] = Register::Value(Value::Null); + break 'outer; + } + + let record_result = return_if_io!(cursor.record()); + let Some(payload) = record_result.as_ref().map(|r| r.get_payload()) + else { + break 'ifnull; + }; + + let mut record_cursor = cursor.record_cursor.borrow_mut(); + + if record_cursor.offsets.is_empty() { + let (header_size, header_len_bytes) = read_varint_fast(payload)?; + let header_size = header_size as usize; + + debug_assert!(header_size <= payload.len() && header_size <= 98307, "header_size: {header_size}, header_len_bytes: {header_len_bytes}, payload.len(): {}", payload.len()); + + record_cursor.header_size = header_size; + record_cursor.header_offset = header_len_bytes; + record_cursor.offsets.push(header_size); + } + + let target_column = *column; + let mut parse_pos = record_cursor.header_offset; + let mut data_offset = record_cursor + .offsets + .last() + .copied() + .expect("header_offset must be set"); + + // Parse the header for serial types incrementally until we have the target column + while record_cursor.serial_types.len() <= target_column + && parse_pos < record_cursor.header_size + { + let (serial_type, varint_len) = + read_varint_fast(&payload[parse_pos..])?; + record_cursor.serial_types.push(serial_type); + parse_pos += varint_len; + let data_size = match serial_type { + // NULL + 0 => 0, + // I8 + 1 => 1, + // I16 + 2 => 2, + // I24 + 3 => 3, + // I32 + 4 => 4, + // I48 + 5 => 6, + // I64 + 6 => 8, + // F64 + 7 => 8, + // CONST_INT0 + 8 => 0, + // CONST_INT1 + 9 => 0, + // BLOB + n if n >= 12 && n & 1 == 0 => (n - 12) >> 1, + // TEXT + n if n >= 13 && n & 1 == 1 => (n - 13) >> 1, + // Reserved + 10 | 11 => { + return Err(LimboError::Corrupt(format!( + "Reserved serial type: {serial_type}" + ))) + } + _ => unreachable!("Invalid serial type: {serial_type}"), + } as usize; + data_offset += data_size; + record_cursor.offsets.push(data_offset); + } + + debug_assert!( + parse_pos <= record_cursor.header_size, + "parse_pos: {parse_pos}, header_size: {}", + record_cursor.header_size + ); + record_cursor.header_offset = parse_pos; + if target_column >= record_cursor.serial_types.len() { + break 'ifnull; + } + + let start_offset = record_cursor.offsets[target_column]; + let end_offset = record_cursor.offsets[target_column + 1]; + // SAFETY: We know that the payload is valid until the next row is processed. + let buf = unsafe { + std::mem::transmute::<&[u8], &'static [u8]>( + &payload[start_offset..end_offset], + ) + }; + let serial_type = record_cursor.serial_types[target_column]; + drop(record_result); + drop(record_cursor); + drop(cursor_ref); + + match serial_type { + // NULL + 0 => break 'ifnull, + // I8 + 1 => { + state.registers[*dest] = + Register::Value(Value::Integer(buf[0] as i8 as i64)); + } + // I16 + 2 => { + state.registers[*dest] = + Register::Value(Value::Integer(i16::from_be_bytes([ + buf[0], buf[1], + ]) + as i64)); + } + // I24 + 3 => { + let sign_extension = (buf[0] > 0x7F) as u8 * 0xFF; + let value = Value::Integer(i32::from_be_bytes([ + sign_extension, + buf[0], + buf[1], + buf[2], + ]) + as i64); + state.registers[*dest] = Register::Value(value); + } + // I32 + 4 => { + let value = Value::Integer(i32::from_be_bytes([ + buf[0], buf[1], buf[2], buf[3], + ]) + as i64); + state.registers[*dest] = Register::Value(value); + } + // I48 + 5 => { + let sign_extension = (buf[0] > 0x7F) as u8 * 0xFF; + let value = Value::Integer(i64::from_be_bytes([ + sign_extension, + sign_extension, + buf[0], + buf[1], + buf[2], + buf[3], + buf[4], + buf[5], + ])); + state.registers[*dest] = Register::Value(value); + } + // I64 + 6 => { + let value = Value::Integer(i64::from_be_bytes( + buf[..8].try_into().unwrap(), + )); + state.registers[*dest] = Register::Value(value); + } + // F64 + 7 => { + let value = Value::Float(f64::from_be_bytes( + buf[..8].try_into().unwrap(), + )); + state.registers[*dest] = Register::Value(value); + } + // CONST_INT0 + 8 => { + state.registers[*dest] = Register::Value(Value::Integer(0)); + } + // CONST_INT1 + 9 => { + state.registers[*dest] = Register::Value(Value::Integer(1)); + } + // BLOB + n if n >= 12 && n & 1 == 0 => { + // Try to reuse the registers when allocation is not needed. + match state.registers[*dest] { + Register::Value(Value::Blob(ref mut existing_blob)) => { + existing_blob.do_extend(&buf); + } + _ => { + state.registers[*dest] = + Register::Value(Value::Blob(buf.to_vec())); + } + } + } + // TEXT + n if n >= 13 && n & 1 == 1 => { + // Try to reuse the registers when allocation is not needed. + match state.registers[*dest] { + Register::Value(Value::Text(ref mut existing_text)) => { + // SAFETY: We know the text is valid UTF-8 because we only accept valid UTF-8 and the serial type is TEXT. + let text = + unsafe { std::str::from_utf8_unchecked(buf) }; + existing_text.do_extend(&text); + } + _ => { + // SAFETY: We know the text is valid UTF-8 because we only accept valid UTF-8 and the serial type is TEXT. + let text = + unsafe { std::str::from_utf8_unchecked(buf) }; + state.registers[*dest] = + Register::Value(Value::Text(Text::new(text))); + } + } + } + _ => panic!("Invalid serial type: {serial_type}"), + } + + break 'outer; + }; + + // DEFAULT handling. Try to reuse the registers when allocation is not needed. + let Some(ref default) = default else { + state.registers[*dest] = Register::Value(Value::Null); + break; + }; + match (default, &mut state.registers[*dest]) { + ( + Value::Text(new_text), + Register::Value(Value::Text(existing_text)), + ) => { + existing_text.do_extend(new_text); + } + ( + Value::Blob(new_blob), + Register::Value(Value::Blob(existing_blob)), + ) => { + existing_blob.do_extend(new_blob); + } + _ => { + state.registers[*dest] = Register::Value(default.clone()); + } + } + break; + } + CursorType::Sorter => { + let record = { + let mut cursor = state.get_cursor(*cursor_id); + let cursor = cursor.as_sorter_mut(); + cursor.record().cloned() + }; + if let Some(record) = record { + state.registers[*dest] = + Register::Value(match record.get_value_opt(*column) { + Some(val) => val.to_owned(), + None => default.clone().unwrap_or(Value::Null), + }); + } else { + state.registers[*dest] = Register::Value(Value::Null); + } + } + CursorType::Pseudo(_) => { + let value = { + let mut cursor = state.get_cursor(*cursor_id); + let cursor = cursor.as_pseudo_mut(); + if let Some(record) = cursor.record() { + record.get_value(*column)?.to_owned() + } else { + Value::Null + } + }; + state.registers[*dest] = Register::Value(value); + } + CursorType::VirtualTable(_) => { + panic!("Insn:Column on virtual table cursor, use Insn:VColumn instead"); + } + } + break; + } } } + state.op_column_state = OpColumnState::Start; state.pc += 1; Ok(InsnFunctionStepResult::Step) } diff --git a/core/vdbe/mod.rs b/core/vdbe/mod.rs index 9760c7d60..3312442cb 100644 --- a/core/vdbe/mod.rs +++ b/core/vdbe/mod.rs @@ -32,8 +32,8 @@ use crate::{ translate::{collate::CollationSeq, plan::TableReferences}, types::{IOResult, RawSlice, TextRef}, vdbe::execute::{ - OpDeleteState, OpDeleteSubState, OpIdxInsertState, OpInsertState, OpInsertSubState, - OpNewRowidState, OpNoConflictState, OpSeekState, + OpColumnState, OpDeleteState, OpDeleteSubState, OpIdxInsertState, OpInsertState, + OpInsertSubState, OpNewRowidState, OpNoConflictState, OpSeekState, }, RefValue, }; @@ -264,6 +264,7 @@ pub struct ProgramState { seek_state: OpSeekState, /// Current collation sequence set by OP_CollSeq instruction current_collation: Option, + op_column_state: OpColumnState, } impl ProgramState { @@ -303,6 +304,7 @@ impl ProgramState { op_no_conflict_state: OpNoConflictState::Start, seek_state: OpSeekState::Start, current_collation: None, + op_column_state: OpColumnState::Start, } }