diff --git a/core/btree.rs b/core/btree.rs index 7ed1ceda9..1c4dedd9c 100644 --- a/core/btree.rs +++ b/core/btree.rs @@ -1,6 +1,6 @@ use crate::pager::Pager; use crate::sqlite3_ondisk::{BTreeCell, TableInteriorCell, TableLeafCell}; -use crate::types::Record; +use crate::types::OwnedRecord; use anyhow::Result; @@ -42,7 +42,7 @@ pub struct Cursor { root_page: usize, page: RefCell>>, rowid: RefCell>, - record: RefCell>, + record: RefCell>, } impl Cursor { @@ -97,7 +97,7 @@ impl Cursor { Ok(self.rowid.borrow()) } - pub fn record(&self) -> Result>> { + pub fn record(&self) -> Result>> { Ok(self.record.borrow()) } @@ -105,7 +105,7 @@ impl Cursor { self.record.borrow().is_some() } - fn get_next_record(&mut self) -> Result, Option)>> { + fn get_next_record(&mut self) -> Result, Option)>> { loop { let mem_page = { let mem_page = self.page.borrow(); @@ -152,7 +152,7 @@ impl Cursor { } BTreeCell::TableLeafCell(TableLeafCell { _rowid, _payload }) => { mem_page.advance(); - let record = crate::sqlite3_ondisk::read_record(_payload)?; + let record= crate::sqlite3_ondisk::read_record(_payload)?; return Ok(CursorResult::Ok((Some(*_rowid), Some(record)))); } } diff --git a/core/lib.rs b/core/lib.rs index 0bae65495..628431145 100644 --- a/core/lib.rs +++ b/core/lib.rs @@ -100,10 +100,7 @@ impl Connection { match cmd { Cmd::Stmt(stmt) => { let program = Arc::new(translate::translate(&self.schema, stmt)?); - Ok(Statement { - program, - pager: self.pager.clone(), - }) + Ok(Statement::new(program, self.pager.clone())) } Cmd::Explain(_stmt) => todo!(), Cmd::ExplainQueryPlan(_stmt) => todo!(), @@ -121,8 +118,8 @@ impl Connection { match cmd { Cmd::Stmt(stmt) => { let program = Arc::new(translate::translate(&self.schema, stmt)?); - let state = vdbe::ProgramState::new(program.max_registers); - Ok(Some(Rows::new(state, program, self.pager.clone()))) + let stmt = Statement::new(program, self.pager.clone()); + Ok(Some(Rows { stmt })) } Cmd::Explain(stmt) => { let program = translate::translate(&self.schema, stmt)?; @@ -160,40 +157,21 @@ impl Connection { pub struct Statement { program: Arc, + state: vdbe::ProgramState, pager: Arc, } impl Statement { - pub fn query(&self) -> Result { - let state = vdbe::ProgramState::new(self.program.max_registers); - Ok(Rows::new(state, self.program.clone(), self.pager.clone())) - } - - pub fn reset(&self) {} -} - -pub enum RowResult { - Row(Row), - IO, - Done, -} - -pub struct Rows { - state: vdbe::ProgramState, - program: Arc, - pager: Arc, -} - -impl Rows { - pub fn new(state: vdbe::ProgramState, program: Arc, pager: Arc) -> Self { + pub fn new(program: Arc, pager: Arc) -> Self { + let state = vdbe::ProgramState::new(program.max_registers); Self { - state, program, + state, pager, } } - pub fn next(&mut self) -> Result { + pub fn step<'a>(&'a mut self) -> Result> { loop { let result = self.program.step(&mut self.state, self.pager.clone())?; match result { @@ -209,15 +187,42 @@ impl Rows { } } } + + pub fn query(&mut self) -> Result { + let stmt = Statement::new(self.program.clone(), self.pager.clone()); + Ok(Rows::new(stmt)) + } + + pub fn reset(&self) {} } -pub struct Row { - pub values: Vec, +pub enum RowResult<'a> { + Row(Row<'a>), + IO, + Done, } -impl Row { +pub struct Row<'a> { + pub values: Vec>, +} + +impl<'a> Row<'a> { pub fn get(&self, idx: usize) -> Result { let value = &self.values[idx]; T::from_value(value) } } + +pub struct Rows { + stmt: Statement, +} + +impl Rows { + pub fn new(stmt: Statement) -> Self { + Self { stmt } + } + + pub fn next<'a>(&'a mut self) -> Result> { + self.stmt.step() + } +} diff --git a/core/sqlite3_ondisk.rs b/core/sqlite3_ondisk.rs index b6a8174ad..4de3278f2 100644 --- a/core/sqlite3_ondisk.rs +++ b/core/sqlite3_ondisk.rs @@ -26,10 +26,9 @@ use crate::buffer_pool::BufferPool; use crate::io::{Buffer, Completion}; use crate::pager::Page; -use crate::types::{Record, Value}; +use crate::types::{OwnedRecord, OwnedValue}; use crate::PageSource; use anyhow::{anyhow, Result}; -use std::rc::Rc; use std::sync::{Arc, Mutex}; use log::trace; @@ -296,7 +295,7 @@ impl TryFrom for SerialType { } } -pub fn read_record(payload: &[u8]) -> Result { +pub fn read_record(payload: &[u8]) -> Result { let mut pos = 0; let (header_size, nr) = read_varint(payload)?; assert!((header_size as usize) >= nr); @@ -318,24 +317,24 @@ pub fn read_record(payload: &[u8]) -> Result { pos += usize; values.push(value); } - Ok(Record { values }) + Ok(OwnedRecord { values }) } -pub fn read_value(buf: &[u8], serial_type: &SerialType) -> Result<(Value, usize)> { +pub fn read_value(buf: & [u8], serial_type: &SerialType) -> Result<(OwnedValue, usize)> { match *serial_type { - SerialType::Null => Ok((Value::Null, 0)), + SerialType::Null => Ok((OwnedValue::Null, 0)), SerialType::UInt8 => { if buf.len() < 1 { return Err(anyhow!("Invalid UInt8 value")); } - Ok((Value::Integer(buf[0] as i64), 1)) + Ok((OwnedValue::Integer(buf[0] as i64), 1)) } SerialType::BEInt16 => { if buf.len() < 2 { return Err(anyhow!("Invalid BEInt16 value")); } Ok(( - Value::Integer(i16::from_be_bytes([buf[0], buf[1]]) as i64), + OwnedValue::Integer(i16::from_be_bytes([buf[0], buf[1]]) as i64), 2, )) } @@ -344,7 +343,7 @@ pub fn read_value(buf: &[u8], serial_type: &SerialType) -> Result<(Value, usize) return Err(anyhow!("Invalid BEInt24 value")); } Ok(( - Value::Integer(i32::from_be_bytes([0, buf[0], buf[1], buf[2]]) as i64), + OwnedValue::Integer(i32::from_be_bytes([0, buf[0], buf[1], buf[2]]) as i64), 3, )) } @@ -353,7 +352,7 @@ pub fn read_value(buf: &[u8], serial_type: &SerialType) -> Result<(Value, usize) return Err(anyhow!("Invalid BEInt32 value")); } Ok(( - Value::Integer(i32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]) as i64), + OwnedValue::Integer(i32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]) as i64), 4, )) } @@ -362,7 +361,7 @@ pub fn read_value(buf: &[u8], serial_type: &SerialType) -> Result<(Value, usize) return Err(anyhow!("Invalid BEInt48 value")); } Ok(( - Value::Integer(i64::from_be_bytes([ + OwnedValue::Integer(i64::from_be_bytes([ 0, 0, buf[0], buf[1], buf[2], buf[3], buf[4], buf[5], ])), 6, @@ -373,7 +372,7 @@ pub fn read_value(buf: &[u8], serial_type: &SerialType) -> Result<(Value, usize) return Err(anyhow!("Invalid BEInt64 value")); } Ok(( - Value::Integer(i64::from_be_bytes([ + OwnedValue::Integer(i64::from_be_bytes([ buf[0], buf[1], buf[2], buf[3], buf[4], buf[5], buf[6], buf[7], ])), 8, @@ -384,19 +383,19 @@ pub fn read_value(buf: &[u8], serial_type: &SerialType) -> Result<(Value, usize) return Err(anyhow!("Invalid BEFloat64 value")); } Ok(( - Value::Float(f64::from_be_bytes([ + OwnedValue::Float(f64::from_be_bytes([ buf[0], buf[1], buf[2], buf[3], buf[4], buf[5], buf[6], buf[7], ])), 8, )) } - SerialType::ConstInt0 => Ok((Value::Integer(0), 0)), - SerialType::ConstInt1 => Ok((Value::Integer(1), 0)), + SerialType::ConstInt0 => Ok((OwnedValue::Integer(0), 0)), + SerialType::ConstInt1 => Ok((OwnedValue::Integer(1), 0)), SerialType::Blob(n) => { if buf.len() < n { return Err(anyhow!("Invalid Blob value")); } - Ok((Value::Blob(buf[0..n].to_vec()), n)) + Ok((OwnedValue::Blob(buf[0..n].to_vec()), n)) } SerialType::String(n) => { if buf.len() < n { @@ -404,7 +403,7 @@ pub fn read_value(buf: &[u8], serial_type: &SerialType) -> Result<(Value, usize) } let bytes = buf[0..n].to_vec(); let value = unsafe { String::from_utf8_unchecked(bytes) }; - Ok((Value::Text(Rc::new(value)), n)) + Ok((OwnedValue::Text(value), n)) } } } @@ -458,22 +457,22 @@ mod tests { } #[rstest] - #[case(&[], SerialType::Null, Value::Null)] - #[case(&[255], SerialType::UInt8, Value::Integer(255))] - #[case(&[0x12, 0x34], SerialType::BEInt16, Value::Integer(0x1234))] - #[case(&[0x12, 0x34, 0x56], SerialType::BEInt24, Value::Integer(0x123456))] - #[case(&[0x12, 0x34, 0x56, 0x78], SerialType::BEInt32, Value::Integer(0x12345678))] - #[case(&[0x12, 0x34, 0x56, 0x78, 0x9A, 0xBC], SerialType::BEInt48, Value::Integer(0x123456789ABC))] - #[case(&[0x12, 0x34, 0x56, 0x78, 0x9A, 0xBC, 0xDE, 0xFF], SerialType::BEInt64, Value::Integer(0x123456789ABCDEFF))] - #[case(&[64, 9, 33, 251, 84, 68, 45, 24], SerialType::BEFloat64, Value::Float(3.141592653589793))] - #[case(&[], SerialType::ConstInt0, Value::Integer(0))] - #[case(&[], SerialType::ConstInt1, Value::Integer(1))] - #[case(&[1, 2, 3], SerialType::Blob(3), Value::Blob(vec![1, 2, 3]))] - #[case(&[65, 66, 67], SerialType::String(3), Value::Text("ABC".to_string().into()))] + #[case(&[], SerialType::Null, OwnedValue::Null)] + #[case(&[255], SerialType::UInt8, OwnedValue::Integer(255))] + #[case(&[0x12, 0x34], SerialType::BEInt16, OwnedValue::Integer(0x1234))] + #[case(&[0x12, 0x34, 0x56], SerialType::BEInt24, OwnedValue::Integer(0x123456))] + #[case(&[0x12, 0x34, 0x56, 0x78], SerialType::BEInt32, OwnedValue::Integer(0x12345678))] + #[case(&[0x12, 0x34, 0x56, 0x78, 0x9A, 0xBC], SerialType::BEInt48, OwnedValue::Integer(0x123456789ABC))] + #[case(&[0x12, 0x34, 0x56, 0x78, 0x9A, 0xBC, 0xDE, 0xFF], SerialType::BEInt64, OwnedValue::Integer(0x123456789ABCDEFF))] + #[case(&[64, 9, 33, 251, 84, 68, 45, 24], SerialType::BEFloat64, OwnedValue::Float(3.141592653589793))] + #[case(&[], SerialType::ConstInt0, OwnedValue::Integer(0))] + #[case(&[], SerialType::ConstInt1, OwnedValue::Integer(1))] + #[case(&[1, 2, 3], SerialType::Blob(3), OwnedValue::Blob(vec![1, 2, 3]))] + #[case(&[65, 66, 67], SerialType::String(3), OwnedValue::Text("ABC".to_string()))] fn test_read_value( #[case] buf: &[u8], #[case] serial_type: SerialType, - #[case] expected: Value, + #[case] expected: OwnedValue, ) { let result = read_value(buf, &serial_type).unwrap(); assert_eq!(result, (expected, buf.len())); diff --git a/core/types.rs b/core/types.rs index 071af17e0..62ce931e4 100644 --- a/core/types.rs +++ b/core/types.rs @@ -1,16 +1,33 @@ -use std::rc::Rc; - use anyhow::Result; #[derive(Debug, Clone, PartialEq)] -pub enum Value { +pub enum Value<'a> { Null, Integer(i64), Float(f64), - Text(Rc), + Text(&'a String), + Blob(&'a Vec), +} + +#[derive(Debug, Clone, PartialEq)] +pub enum OwnedValue { + Null, + Integer(i64), + Float(f64), + Text(String), Blob(Vec), } +pub fn to_value<'a>(value: &'a OwnedValue) -> Value<'a> { + match value { + OwnedValue::Null => Value::Null, + OwnedValue::Integer(i) => Value::Integer(*i), + OwnedValue::Float(f) => Value::Float(*f), + OwnedValue::Text(s) => Value::Text(s), + OwnedValue::Blob(b) => Value::Blob(b), + } +} + pub trait FromValue { fn from_value(value: &Value) -> Result where @@ -36,12 +53,22 @@ impl FromValue for String { } #[derive(Debug)] -pub struct Record { - pub values: Vec, +pub struct Record<'a> { + pub values: Vec>, } -impl Record { - pub fn new(values: Vec) -> Self { +impl<'a> Record<'a> { + pub fn new(values: Vec>) -> Self { Self { values } } } + +pub struct OwnedRecord { + pub values: Vec, +} + +impl OwnedRecord { + pub fn new(values: Vec) -> Self { + Self { values } + } +} \ No newline at end of file diff --git a/core/vdbe.rs b/core/vdbe.rs index 3c53a2d7e..9467bdc00 100644 --- a/core/vdbe.rs +++ b/core/vdbe.rs @@ -1,8 +1,9 @@ use crate::btree::{Cursor, CursorResult}; use crate::pager::Pager; -use crate::types::{Record, Value}; +use crate::types::{OwnedValue, Record}; use anyhow::Result; +use std::cell::RefCell; use std::collections::BTreeMap; use std::sync::Arc; @@ -143,24 +144,24 @@ impl ProgramBuilder { } } -pub enum StepResult { +pub enum StepResult<'a> { Done, IO, - Row(Record), + Row(Record<'a>), } /// The program state describes the environment in which the program executes. pub struct ProgramState { pub pc: usize, - cursors: BTreeMap, - registers: Vec, + cursors: RefCell>, + registers: Vec, } impl ProgramState { pub fn new(max_registers: usize) -> Self { - let cursors = BTreeMap::new(); + let cursors = RefCell::new(BTreeMap::new()); let mut registers = Vec::with_capacity(max_registers); - registers.resize(max_registers, Value::Null); + registers.resize(max_registers, OwnedValue::Null); Self { pc: 0, cursors, @@ -191,10 +192,11 @@ impl Program { } } - pub fn step(&self, state: &mut ProgramState, pager: Arc) -> Result { + pub fn step<'a>(&self, state: &'a mut ProgramState, pager: Arc) -> Result> { loop { let insn = &self.insns[state.pc]; trace_insn(state.pc, insn); + let mut cursors = state.cursors.borrow_mut(); match insn { Insn::Init { target_pc } => { state.pc = *target_pc; @@ -204,14 +206,14 @@ impl Program { root_page, } => { let cursor = Cursor::new(pager.clone(), *root_page); - state.cursors.insert(*cursor_id, cursor); + cursors.insert(*cursor_id, cursor); state.pc += 1; } Insn::OpenReadAwait => { state.pc += 1; } Insn::RewindAsync { cursor_id } => { - let cursor = state.cursors.get_mut(cursor_id).unwrap(); + let cursor = cursors.get_mut(cursor_id).unwrap(); match cursor.rewind()? { CursorResult::Ok(()) => {} CursorResult::IO => { @@ -225,7 +227,7 @@ impl Program { cursor_id, pc_if_empty, } => { - let cursor = state.cursors.get_mut(cursor_id).unwrap(); + let cursor = cursors.get_mut(cursor_id).unwrap(); cursor.wait_for_completion()?; if cursor.is_empty() { state.pc = *pc_if_empty; @@ -238,7 +240,7 @@ impl Program { column, dest, } => { - let cursor = state.cursors.get_mut(cursor_id).unwrap(); + let cursor = cursors.get_mut(cursor_id).unwrap(); if let Some(ref record) = *cursor.record()? { state.registers[*dest] = record.values[*column].clone(); } else { @@ -252,13 +254,13 @@ impl Program { } => { let mut values = Vec::with_capacity(*register_end - *register_start); for i in *register_start..*register_end { - values.push(state.registers[i].clone()); + values.push(crate::types::to_value(&state.registers[i])); } state.pc += 1; return Ok(StepResult::Row(Record::new(values))); } Insn::NextAsync { cursor_id } => { - let cursor = state.cursors.get_mut(cursor_id).unwrap(); + let cursor = cursors.get_mut(cursor_id).unwrap(); match cursor.next()? { CursorResult::Ok(_) => {} CursorResult::IO => { @@ -272,7 +274,7 @@ impl Program { cursor_id, pc_if_next, } => { - let cursor = state.cursors.get_mut(cursor_id).unwrap(); + let cursor = cursors.get_mut(cursor_id).unwrap(); cursor.wait_for_completion()?; if cursor.has_record() { state.pc = *pc_if_next; @@ -290,22 +292,22 @@ impl Program { state.pc = *target_pc; } Insn::Integer { value, dest } => { - state.registers[*dest] = Value::Integer(*value); + state.registers[*dest] = OwnedValue::Integer(*value); state.pc += 1; } Insn::RowId { cursor_id, dest } => { - let cursor = state.cursors.get_mut(cursor_id).unwrap(); + let cursor = cursors.get_mut(cursor_id).unwrap(); if let Some(ref rowid) = *cursor.rowid()? { - state.registers[*dest] = Value::Integer(*rowid as i64); + state.registers[*dest] = OwnedValue::Integer(*rowid as i64); } else { todo!(); } state.pc += 1; } Insn::DecrJumpZero { reg, target_pc } => match state.registers[*reg] { - Value::Integer(n) => { + OwnedValue::Integer(n) => { if n > 0 { - state.registers[*reg] = Value::Integer(n - 1); + state.registers[*reg] = OwnedValue::Integer(n - 1); state.pc += 1; } else { state.pc = *target_pc;