use bitfield for ended_coroutine

This commit is contained in:
Jussi Saurio
2025-01-26 13:12:47 +02:00
parent 9e32ce6c77
commit b687cf66eb

View File

@@ -238,6 +238,29 @@ fn get_cursor_as_sorter_mut<'long, 'short>(
cursor
}
struct Bitfield<const N: usize>([u64; N]);
impl<const N: usize> Bitfield<N> {
fn new() -> Self {
Self([0; N])
}
fn set(&mut self, bit: usize) {
assert!(bit < N * 64, "bit out of bounds");
self.0[bit / 64] |= 1 << (bit % 64);
}
fn unset(&mut self, bit: usize) {
assert!(bit < N * 64, "bit out of bounds");
self.0[bit / 64] &= !(1 << (bit % 64));
}
fn get(&self, bit: usize) -> bool {
assert!(bit < N * 64, "bit out of bounds");
(self.0[bit / 64] & (1 << (bit % 64))) != 0
}
}
/// The program state describes the environment in which the program executes.
pub struct ProgramState {
pub pc: InsnReference,
@@ -245,7 +268,7 @@ pub struct ProgramState {
registers: Vec<OwnedValue>,
last_compare: Option<std::cmp::Ordering>,
deferred_seek: Option<(CursorID, CursorID)>,
ended_coroutine: HashMap<usize, bool>, // flag to indicate that a coroutine has ended (key is the yield register)
ended_coroutine: Bitfield<4>, // flag to indicate that a coroutine has ended (key is the yield register. currently we assume that the yield register is always between 0-255, YOLO)
regex_cache: RegexCache,
interrupted: bool,
parameters: HashMap<NonZero<usize>, OwnedValue>,
@@ -262,7 +285,7 @@ impl ProgramState {
registers,
last_compare: None,
deferred_seek: None,
ended_coroutine: HashMap::new(),
ended_coroutine: Bitfield::new(),
regex_cache: RegexCache::new(),
interrupted: false,
parameters: HashMap::new(),
@@ -301,7 +324,7 @@ impl ProgramState {
self.registers.resize(max_registers, OwnedValue::Null);
self.last_compare = None;
self.deferred_seek = None;
self.ended_coroutine.clear();
self.ended_coroutine.0 = [0; 4];
self.regex_cache.like.clear();
self.interrupted = false;
self.parameters.clear();
@@ -2086,7 +2109,7 @@ impl Program {
assert!(jump_on_definition.is_offset());
let start_offset = start_offset.to_offset_int();
state.registers[*yield_reg] = OwnedValue::Integer(start_offset as i64);
state.ended_coroutine.insert(*yield_reg, false);
state.ended_coroutine.unset(*yield_reg);
let jump_on_definition = jump_on_definition.to_offset_int();
state.pc = if jump_on_definition == 0 {
state.pc + 1
@@ -2096,7 +2119,7 @@ impl Program {
}
Insn::EndCoroutine { yield_reg } => {
if let OwnedValue::Integer(pc) = state.registers[*yield_reg] {
state.ended_coroutine.insert(*yield_reg, true);
state.ended_coroutine.set(*yield_reg);
let pc: u32 = pc
.try_into()
.unwrap_or_else(|_| panic!("EndCoroutine: pc overflow: {}", pc));
@@ -2110,11 +2133,7 @@ impl Program {
end_offset,
} => {
if let OwnedValue::Integer(pc) = state.registers[*yield_reg] {
if *state
.ended_coroutine
.get(yield_reg)
.expect("coroutine not initialized")
{
if state.ended_coroutine.get(*yield_reg) {
state.pc = end_offset.to_offset_int();
} else {
let pc: u32 = pc
@@ -3403,7 +3422,7 @@ mod tests {
exec_ltrim, exec_max, exec_min, exec_nullif, exec_quote, exec_random, exec_randomblob,
exec_round, exec_rtrim, exec_sign, exec_soundex, exec_substring, exec_trim, exec_typeof,
exec_unhex, exec_unicode, exec_upper, exec_zeroblob, execute_sqlite_version, AggContext,
OwnedValue,
Bitfield, OwnedValue,
};
use std::{collections::HashMap, rc::Rc};
@@ -4292,4 +4311,23 @@ mod tests {
expected_str
);
}
#[test]
fn test_bitfield() {
let mut bitfield = Bitfield::<4>::new();
for i in 0..256 {
bitfield.set(i);
assert!(bitfield.get(i));
for j in 0..i {
assert!(bitfield.get(j));
}
for j in i + 1..256 {
assert!(!bitfield.get(j));
}
}
for i in 0..256 {
bitfield.unset(i);
assert!(!bitfield.get(i));
}
}
}