state machine for op_column

This commit is contained in:
pedrocarlo
2025-08-11 12:48:13 -03:00
parent fe0e4bcbb7
commit 1221f65d10
2 changed files with 340 additions and 285 deletions

View File

@@ -1359,6 +1359,20 @@ fn read_varint_fast(buf: &[u8]) -> Result<(u64, usize)> {
read_varint(buf)
}
#[derive(Debug, Clone, Copy)]
pub enum OpColumnState {
Start,
Rowid {
index_cursor_id: usize,
table_cursor_id: usize,
},
Seek {
rowid: i64,
table_cursor_id: usize,
},
GetColumn,
}
pub fn op_column(
program: &Program,
state: &mut ProgramState,
@@ -1375,295 +1389,334 @@ pub fn op_column(
},
insn
);
if let Some((index_cursor_id, table_cursor_id)) = state.deferred_seeks[*cursor_id].take() {
let deferred_seek = 'd: {
let rowid = {
let mut index_cursor = state.get_cursor(index_cursor_id);
let index_cursor = index_cursor.as_btree_mut();
match index_cursor.rowid()? {
IOResult::IO => {
break 'd Some((index_cursor_id, table_cursor_id));
}
IOResult::Done(rowid) => rowid,
}
};
let mut table_cursor = state.get_cursor(table_cursor_id);
let table_cursor = table_cursor.as_btree_mut();
match table_cursor.seek(
SeekKey::TableRowId(rowid.unwrap()),
SeekOp::GE { eq_only: true },
)? {
IOResult::Done(_) => None,
IOResult::IO => Some((index_cursor_id, table_cursor_id)),
}
};
if let Some(deferred_seek) = deferred_seek {
state.deferred_seeks[*cursor_id] = Some(deferred_seek);
return Ok(InsnFunctionStepResult::IO);
}
}
let (_, cursor_type) = program.cursor_ref.get(*cursor_id).unwrap();
match cursor_type {
CursorType::BTreeTable(_) | CursorType::BTreeIndex(_) => {
'ifnull: {
let mut cursor_ref =
must_be_btree_cursor!(*cursor_id, program.cursor_ref, state, "Column");
let cursor = cursor_ref.as_btree_mut();
if cursor.get_null_flag() {
drop(cursor_ref);
state.registers[*dest] = Register::Value(Value::Null);
state.pc += 1;
return Ok(InsnFunctionStepResult::Step);
}
let record_result = return_if_io!(cursor.record());
let Some(payload) = record_result.as_ref().map(|r| r.get_payload()) else {
break 'ifnull;
};
let mut record_cursor = cursor.record_cursor.borrow_mut();
if record_cursor.offsets.is_empty() {
let (header_size, header_len_bytes) = read_varint_fast(payload)?;
let header_size = header_size as usize;
debug_assert!(header_size <= payload.len() && header_size <= 98307, "header_size: {header_size}, header_len_bytes: {header_len_bytes}, payload.len(): {}", payload.len());
record_cursor.header_size = header_size;
record_cursor.header_offset = header_len_bytes;
record_cursor.offsets.push(header_size);
}
let target_column = *column;
let mut parse_pos = record_cursor.header_offset;
let mut data_offset = record_cursor
.offsets
.last()
.copied()
.expect("header_offset must be set");
// Parse the header for serial types incrementally until we have the target column
while record_cursor.serial_types.len() <= target_column
&& parse_pos < record_cursor.header_size
'outer: loop {
match state.op_column_state {
OpColumnState::Start => {
if let Some((index_cursor_id, table_cursor_id)) =
state.deferred_seeks[*cursor_id].take()
{
let (serial_type, varint_len) = read_varint_fast(&payload[parse_pos..])?;
record_cursor.serial_types.push(serial_type);
parse_pos += varint_len;
let data_size = match serial_type {
// NULL
0 => 0,
// I8
1 => 1,
// I16
2 => 2,
// I24
3 => 3,
// I32
4 => 4,
// I48
5 => 6,
// I64
6 => 8,
// F64
7 => 8,
// CONST_INT0
8 => 0,
// CONST_INT1
9 => 0,
// BLOB
n if n >= 12 && n & 1 == 0 => (n - 12) >> 1,
// TEXT
n if n >= 13 && n & 1 == 1 => (n - 13) >> 1,
// Reserved
10 | 11 => {
return Err(LimboError::Corrupt(format!(
"Reserved serial type: {serial_type}"
)))
}
_ => unreachable!("Invalid serial type: {serial_type}"),
} as usize;
data_offset += data_size;
record_cursor.offsets.push(data_offset);
}
debug_assert!(
parse_pos <= record_cursor.header_size,
"parse_pos: {parse_pos}, header_size: {}",
record_cursor.header_size
);
record_cursor.header_offset = parse_pos;
if target_column >= record_cursor.serial_types.len() {
break 'ifnull;
}
let start_offset = record_cursor.offsets[target_column];
let end_offset = record_cursor.offsets[target_column + 1];
// SAFETY: We know that the payload is valid until the next row is processed.
let buf = unsafe {
std::mem::transmute::<&[u8], &'static [u8]>(&payload[start_offset..end_offset])
};
let serial_type = record_cursor.serial_types[target_column];
drop(record_result);
drop(record_cursor);
drop(cursor_ref);
match serial_type {
// NULL
0 => break 'ifnull,
// I8
1 => {
state.registers[*dest] =
Register::Value(Value::Integer(buf[0] as i8 as i64));
}
// I16
2 => {
state.registers[*dest] = Register::Value(Value::Integer(
i16::from_be_bytes([buf[0], buf[1]]) as i64,
));
}
// I24
3 => {
let sign_extension = (buf[0] > 0x7F) as u8 * 0xFF;
let value = Value::Integer(i32::from_be_bytes([
sign_extension,
buf[0],
buf[1],
buf[2],
]) as i64);
state.registers[*dest] = Register::Value(value);
}
// I32
4 => {
let value =
Value::Integer(
i32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]) as i64
);
state.registers[*dest] = Register::Value(value);
}
// I48
5 => {
let sign_extension = (buf[0] > 0x7F) as u8 * 0xFF;
let value = Value::Integer(i64::from_be_bytes([
sign_extension,
sign_extension,
buf[0],
buf[1],
buf[2],
buf[3],
buf[4],
buf[5],
]));
state.registers[*dest] = Register::Value(value);
}
// I64
6 => {
let value =
Value::Integer(i64::from_be_bytes(buf[..8].try_into().unwrap()));
state.registers[*dest] = Register::Value(value);
}
// F64
7 => {
let value = Value::Float(f64::from_be_bytes(buf[..8].try_into().unwrap()));
state.registers[*dest] = Register::Value(value);
}
// CONST_INT0
8 => {
state.registers[*dest] = Register::Value(Value::Integer(0));
}
// CONST_INT1
9 => {
state.registers[*dest] = Register::Value(Value::Integer(1));
}
// BLOB
n if n >= 12 && n & 1 == 0 => {
// Try to reuse the registers when allocation is not needed.
match state.registers[*dest] {
Register::Value(Value::Blob(ref mut existing_blob)) => {
existing_blob.do_extend(&buf);
}
_ => {
state.registers[*dest] = Register::Value(Value::Blob(buf.to_vec()));
}
}
}
// TEXT
n if n >= 13 && n & 1 == 1 => {
// Try to reuse the registers when allocation is not needed.
match state.registers[*dest] {
Register::Value(Value::Text(ref mut existing_text)) => {
// SAFETY: We know the text is valid UTF-8 because we only accept valid UTF-8 and the serial type is TEXT.
let text = unsafe { std::str::from_utf8_unchecked(buf) };
existing_text.do_extend(&text);
}
_ => {
// SAFETY: We know the text is valid UTF-8 because we only accept valid UTF-8 and the serial type is TEXT.
let text = unsafe { std::str::from_utf8_unchecked(buf) };
state.registers[*dest] =
Register::Value(Value::Text(Text::new(text)));
}
}
}
_ => panic!("Invalid serial type: {serial_type}"),
}
state.pc += 1;
return Ok(InsnFunctionStepResult::Step);
};
// DEFAULT handling. Try to reuse the registers when allocation is not needed.
let Some(ref default) = default else {
state.registers[*dest] = Register::Value(Value::Null);
state.pc += 1;
return Ok(InsnFunctionStepResult::Step);
};
match (default, &mut state.registers[*dest]) {
(Value::Text(new_text), Register::Value(Value::Text(existing_text))) => {
existing_text.do_extend(new_text);
}
(Value::Blob(new_blob), Register::Value(Value::Blob(existing_blob))) => {
existing_blob.do_extend(new_blob);
}
_ => {
state.registers[*dest] = Register::Value(default.clone());
}
}
state.pc += 1;
return Ok(InsnFunctionStepResult::Step);
}
CursorType::Sorter => {
let record = {
let mut cursor = state.get_cursor(*cursor_id);
let cursor = cursor.as_sorter_mut();
cursor.record().cloned()
};
if let Some(record) = record {
state.registers[*dest] = Register::Value(match record.get_value_opt(*column) {
Some(val) => val.to_owned(),
None => default.clone().unwrap_or(Value::Null),
});
} else {
state.registers[*dest] = Register::Value(Value::Null);
}
}
CursorType::Pseudo(_) => {
let value = {
let mut cursor = state.get_cursor(*cursor_id);
let cursor = cursor.as_pseudo_mut();
if let Some(record) = cursor.record() {
record.get_value(*column)?.to_owned()
state.op_column_state = OpColumnState::Rowid {
index_cursor_id,
table_cursor_id,
};
} else {
Value::Null
state.op_column_state = OpColumnState::GetColumn;
}
};
state.registers[*dest] = Register::Value(value);
}
CursorType::VirtualTable(_) => {
panic!("Insn:Column on virtual table cursor, use Insn:VColumn instead");
}
OpColumnState::Rowid {
index_cursor_id,
table_cursor_id,
} => {
let rowid = {
let mut index_cursor = state.get_cursor(index_cursor_id);
let index_cursor = index_cursor.as_btree_mut();
return_if_io!(index_cursor.rowid())
};
state.op_column_state = OpColumnState::Seek {
rowid: rowid.unwrap(),
table_cursor_id,
};
}
OpColumnState::Seek {
rowid,
table_cursor_id,
} => {
{
let mut table_cursor = state.get_cursor(table_cursor_id);
let table_cursor = table_cursor.as_btree_mut();
return_if_io!(
table_cursor.seek(SeekKey::TableRowId(rowid), SeekOp::GE { eq_only: true })
);
}
state.op_column_state = OpColumnState::GetColumn;
}
OpColumnState::GetColumn => {
let (_, cursor_type) = program.cursor_ref.get(*cursor_id).unwrap();
match cursor_type {
CursorType::BTreeTable(_) | CursorType::BTreeIndex(_) => {
'ifnull: {
let mut cursor_ref = must_be_btree_cursor!(
*cursor_id,
program.cursor_ref,
state,
"Column"
);
let cursor = cursor_ref.as_btree_mut();
if cursor.get_null_flag() {
drop(cursor_ref);
state.registers[*dest] = Register::Value(Value::Null);
break 'outer;
}
let record_result = return_if_io!(cursor.record());
let Some(payload) = record_result.as_ref().map(|r| r.get_payload())
else {
break 'ifnull;
};
let mut record_cursor = cursor.record_cursor.borrow_mut();
if record_cursor.offsets.is_empty() {
let (header_size, header_len_bytes) = read_varint_fast(payload)?;
let header_size = header_size as usize;
debug_assert!(header_size <= payload.len() && header_size <= 98307, "header_size: {header_size}, header_len_bytes: {header_len_bytes}, payload.len(): {}", payload.len());
record_cursor.header_size = header_size;
record_cursor.header_offset = header_len_bytes;
record_cursor.offsets.push(header_size);
}
let target_column = *column;
let mut parse_pos = record_cursor.header_offset;
let mut data_offset = record_cursor
.offsets
.last()
.copied()
.expect("header_offset must be set");
// Parse the header for serial types incrementally until we have the target column
while record_cursor.serial_types.len() <= target_column
&& parse_pos < record_cursor.header_size
{
let (serial_type, varint_len) =
read_varint_fast(&payload[parse_pos..])?;
record_cursor.serial_types.push(serial_type);
parse_pos += varint_len;
let data_size = match serial_type {
// NULL
0 => 0,
// I8
1 => 1,
// I16
2 => 2,
// I24
3 => 3,
// I32
4 => 4,
// I48
5 => 6,
// I64
6 => 8,
// F64
7 => 8,
// CONST_INT0
8 => 0,
// CONST_INT1
9 => 0,
// BLOB
n if n >= 12 && n & 1 == 0 => (n - 12) >> 1,
// TEXT
n if n >= 13 && n & 1 == 1 => (n - 13) >> 1,
// Reserved
10 | 11 => {
return Err(LimboError::Corrupt(format!(
"Reserved serial type: {serial_type}"
)))
}
_ => unreachable!("Invalid serial type: {serial_type}"),
} as usize;
data_offset += data_size;
record_cursor.offsets.push(data_offset);
}
debug_assert!(
parse_pos <= record_cursor.header_size,
"parse_pos: {parse_pos}, header_size: {}",
record_cursor.header_size
);
record_cursor.header_offset = parse_pos;
if target_column >= record_cursor.serial_types.len() {
break 'ifnull;
}
let start_offset = record_cursor.offsets[target_column];
let end_offset = record_cursor.offsets[target_column + 1];
// SAFETY: We know that the payload is valid until the next row is processed.
let buf = unsafe {
std::mem::transmute::<&[u8], &'static [u8]>(
&payload[start_offset..end_offset],
)
};
let serial_type = record_cursor.serial_types[target_column];
drop(record_result);
drop(record_cursor);
drop(cursor_ref);
match serial_type {
// NULL
0 => break 'ifnull,
// I8
1 => {
state.registers[*dest] =
Register::Value(Value::Integer(buf[0] as i8 as i64));
}
// I16
2 => {
state.registers[*dest] =
Register::Value(Value::Integer(i16::from_be_bytes([
buf[0], buf[1],
])
as i64));
}
// I24
3 => {
let sign_extension = (buf[0] > 0x7F) as u8 * 0xFF;
let value = Value::Integer(i32::from_be_bytes([
sign_extension,
buf[0],
buf[1],
buf[2],
])
as i64);
state.registers[*dest] = Register::Value(value);
}
// I32
4 => {
let value = Value::Integer(i32::from_be_bytes([
buf[0], buf[1], buf[2], buf[3],
])
as i64);
state.registers[*dest] = Register::Value(value);
}
// I48
5 => {
let sign_extension = (buf[0] > 0x7F) as u8 * 0xFF;
let value = Value::Integer(i64::from_be_bytes([
sign_extension,
sign_extension,
buf[0],
buf[1],
buf[2],
buf[3],
buf[4],
buf[5],
]));
state.registers[*dest] = Register::Value(value);
}
// I64
6 => {
let value = Value::Integer(i64::from_be_bytes(
buf[..8].try_into().unwrap(),
));
state.registers[*dest] = Register::Value(value);
}
// F64
7 => {
let value = Value::Float(f64::from_be_bytes(
buf[..8].try_into().unwrap(),
));
state.registers[*dest] = Register::Value(value);
}
// CONST_INT0
8 => {
state.registers[*dest] = Register::Value(Value::Integer(0));
}
// CONST_INT1
9 => {
state.registers[*dest] = Register::Value(Value::Integer(1));
}
// BLOB
n if n >= 12 && n & 1 == 0 => {
// Try to reuse the registers when allocation is not needed.
match state.registers[*dest] {
Register::Value(Value::Blob(ref mut existing_blob)) => {
existing_blob.do_extend(&buf);
}
_ => {
state.registers[*dest] =
Register::Value(Value::Blob(buf.to_vec()));
}
}
}
// TEXT
n if n >= 13 && n & 1 == 1 => {
// Try to reuse the registers when allocation is not needed.
match state.registers[*dest] {
Register::Value(Value::Text(ref mut existing_text)) => {
// SAFETY: We know the text is valid UTF-8 because we only accept valid UTF-8 and the serial type is TEXT.
let text =
unsafe { std::str::from_utf8_unchecked(buf) };
existing_text.do_extend(&text);
}
_ => {
// SAFETY: We know the text is valid UTF-8 because we only accept valid UTF-8 and the serial type is TEXT.
let text =
unsafe { std::str::from_utf8_unchecked(buf) };
state.registers[*dest] =
Register::Value(Value::Text(Text::new(text)));
}
}
}
_ => panic!("Invalid serial type: {serial_type}"),
}
break 'outer;
};
// DEFAULT handling. Try to reuse the registers when allocation is not needed.
let Some(ref default) = default else {
state.registers[*dest] = Register::Value(Value::Null);
break;
};
match (default, &mut state.registers[*dest]) {
(
Value::Text(new_text),
Register::Value(Value::Text(existing_text)),
) => {
existing_text.do_extend(new_text);
}
(
Value::Blob(new_blob),
Register::Value(Value::Blob(existing_blob)),
) => {
existing_blob.do_extend(new_blob);
}
_ => {
state.registers[*dest] = Register::Value(default.clone());
}
}
break;
}
CursorType::Sorter => {
let record = {
let mut cursor = state.get_cursor(*cursor_id);
let cursor = cursor.as_sorter_mut();
cursor.record().cloned()
};
if let Some(record) = record {
state.registers[*dest] =
Register::Value(match record.get_value_opt(*column) {
Some(val) => val.to_owned(),
None => default.clone().unwrap_or(Value::Null),
});
} else {
state.registers[*dest] = Register::Value(Value::Null);
}
}
CursorType::Pseudo(_) => {
let value = {
let mut cursor = state.get_cursor(*cursor_id);
let cursor = cursor.as_pseudo_mut();
if let Some(record) = cursor.record() {
record.get_value(*column)?.to_owned()
} else {
Value::Null
}
};
state.registers[*dest] = Register::Value(value);
}
CursorType::VirtualTable(_) => {
panic!("Insn:Column on virtual table cursor, use Insn:VColumn instead");
}
}
break;
}
}
}
state.op_column_state = OpColumnState::Start;
state.pc += 1;
Ok(InsnFunctionStepResult::Step)
}

View File

@@ -32,8 +32,8 @@ use crate::{
translate::{collate::CollationSeq, plan::TableReferences},
types::{IOResult, RawSlice, TextRef},
vdbe::execute::{
OpDeleteState, OpDeleteSubState, OpIdxInsertState, OpInsertState, OpInsertSubState,
OpNewRowidState, OpNoConflictState, OpSeekState,
OpColumnState, OpDeleteState, OpDeleteSubState, OpIdxInsertState, OpInsertState,
OpInsertSubState, OpNewRowidState, OpNoConflictState, OpSeekState,
},
RefValue,
};
@@ -264,6 +264,7 @@ pub struct ProgramState {
seek_state: OpSeekState,
/// Current collation sequence set by OP_CollSeq instruction
current_collation: Option<CollationSeq>,
op_column_state: OpColumnState,
}
impl ProgramState {
@@ -303,6 +304,7 @@ impl ProgramState {
op_no_conflict_state: OpNoConflictState::Start,
seek_state: OpSeekState::Start,
current_collation: None,
op_column_state: OpColumnState::Start,
}
}