diff --git a/cli/app.rs b/cli/app.rs index d3dbe04e0..0af361b40 100644 --- a/cli/app.rs +++ b/cli/app.rs @@ -1360,55 +1360,28 @@ impl Limbo { } } } - // quoted select list and table name + // FIXME: sqlite has logic to check rowid and optionally preserve it, but it requires + // pragma index_list, and it seems to be relevant only for indexes. let cols_str = cols .iter() .map(|c| quote_ident(c)) .collect::>() .join(", "); let select = format!("SELECT {cols_str} FROM {}", quote_ident(table_name)); - // FIXME: sqlite has logic to check rowid and optionally preserve it, but it requires - // pragma index_list, and it seems to be relevant only for indexes. if let Some(mut rows) = conn.query(select)? { loop { match rows.step()? { StepResult::Row => { let row = rows.row().unwrap(); - let mut vals = Vec::with_capacity(types.len()); - for (i, t) in types.iter().enumerate() { + write!(out, "INSERT INTO {} VALUES(", quote_ident(table_name))?; + for i in 0..cols.len() { + if i > 0 { + out.write_all(b",")?; + } let v = row.get::<&Value>(i)?; - let s = - if t.contains("CHAR") || t.contains("CLOB") || t.contains("TEXT") { - let mut s = String::new(); - s.push('\''); - s.push_str(&v.to_string().replace('\'', "''")); - s.push('\''); - s - } else if t.contains("BLOB") { - match v { - Value::Blob(b) => { - let mut s = String::with_capacity(2 + b.len() * 2); - s.push_str("X'"); - for byte in b { - use std::fmt::Write as _; - let _ = write!(&mut s, "{byte:02x}"); - } - s.push('\''); - s - } - _ => "X''".to_string(), - } - } else { - v.to_string() - }; - vals.push(s); + Self::write_sql_value_from_value(out, v)?; } - writeln!( - out, - "INSERT INTO {} VALUES({});", - quote_ident(table_name), - vals.join(",") - )?; + out.write_all(b");\n")?; } StepResult::IO => rows.run_once()?, StepResult::Done | StepResult::Interrupt => break, @@ -1496,6 +1469,37 @@ impl Limbo { Ok(()) } + fn write_sql_value_from_value(out: &mut W, v: &Value) -> io::Result<()> { + match v { + Value::Null => out.write_all(b"NULL"), + Value::Integer(i) => out.write_all(format!("{i}").as_bytes()), + Value::Float(f) => write!(out, "{f}").map(|_| ()), + Value::Text(s) => { + out.write_all(b"'")?; + let bytes = &s.value; + let mut i = 0; + while i < bytes.len() { + let b = bytes[i]; + if b == b'\'' { + out.write_all(b"''")?; + } else { + out.write_all(&[b])?; + } + i += 1; + } + out.write_all(b"'") + } + Value::Blob(b) => { + out.write_all(b"X'")?; + const HEX: &[u8; 16] = b"0123456789abcdef"; + for &byte in b { + out.write_all(&[HEX[(byte >> 4) as usize], HEX[(byte & 0x0F) as usize]])?; + } + out.write_all(b"'") + } + } + } + fn dump_database(&mut self) -> anyhow::Result<()> { // Move writer out so we don’t hold a field-borrow of self during the call. let mut out = std::mem::take(&mut self.writer).unwrap(); diff --git a/cli/input.rs b/cli/input.rs index 3e4cb177a..bf7c7603d 100644 --- a/cli/input.rs +++ b/cli/input.rs @@ -183,52 +183,83 @@ pub fn get_io(db_location: DbLocation, io_choice: &str) -> anyhow::Result { target: &'a Arc, - // accumulates until we see a statement terminator - buf: String, + // accumulate raw bytes to support non-utf8 BLOB types + buf: Vec, } impl<'a> ApplyWriter<'a> { pub fn new(target: &'a Arc) -> Self { Self { target, - buf: String::new(), + buf: Vec::new(), } } + // Find the next statement terminator ;\n or ;\r\n in a byte buffer. + // Returns (end_idx_inclusive, drain_len), where drain_len includes the newline(s). + fn find_stmt_end(buf: &[u8]) -> Option<(usize, usize)> { + let mut i = 0; + while i < buf.len() { + // Look for ';' + if buf[i] == b';' { + // Accept ;\n + if i + 1 < buf.len() && buf[i + 1] == b'\n' { + return Some((i, 2)); + } + // Accept ;\r\n + if i + 2 < buf.len() && buf[i + 1] == b'\r' && buf[i + 2] == b'\n' { + return Some((i, 3)); + } + } + i += 1; + } + None + } + pub fn flush_complete_statements(&mut self) -> io::Result<()> { - // We emit statements with ";\n". Split conservatively on that - while let Some(idx) = self.buf.find(";\n") { - let stmt = self.buf[..idx + 1].to_string(); - self.exec_stmt(&stmt) - .map_err(|e| io::Error::other(e.to_string()))?; - self.buf.drain(..idx + 2); + while let Some((end_inclusive, drain_len)) = Self::find_stmt_end(&self.buf) { + // Copy stmt bytes [0..=end_inclusive] + let stmt_bytes = self.buf[..=end_inclusive].to_vec(); + // Drain including the trailing newline(s) + self.buf.drain(..end_inclusive + drain_len); + self.exec_stmt_bytes(&stmt_bytes)?; } Ok(()) } + // Handle final trailing statement that ends with ';' followed only by ASCII whitespace. pub fn finish(mut self) -> io::Result<()> { - // Handle a trailing statement missing the final newline - if let Some(idx) = self.buf.rfind(';') { - if self.buf[idx..].starts_with(';') && self.buf[idx + 1..].trim().is_empty() { - let stmt = self.buf[..idx + 1].to_string(); - self.exec_stmt(&stmt) - .map_err(|e| io::Error::other(e.to_string()))?; + // Skip if buffer empty or no ';' + if let Some(semicolon_pos) = self.buf.iter().rposition(|&b| b == b';') { + // Are all bytes after ';' ASCII whitespace? + if self.buf[semicolon_pos + 1..] + .iter() + .all(|&b| matches!(b, b' ' | b'\t' | b'\r' | b'\n')) + { + let stmt_bytes = self.buf[..=semicolon_pos].to_vec(); self.buf.clear(); + self.exec_stmt_bytes(&stmt_bytes)?; } } Ok(()) } + fn exec_stmt_bytes(&self, stmt_bytes: &[u8]) -> io::Result<()> { + // SQL must be UTF-8. If not, surface a clear error. + let sql = std::str::from_utf8(stmt_bytes).map_err(|e| { + io::Error::new(io::ErrorKind::InvalidData, format!("non-UTF8 SQL: {e}")) + })?; + self.exec_stmt(sql) + .map_err(|e| io::Error::other(e.to_string())) + } + fn exec_stmt(&self, sql: &str) -> Result<(), LimboError> { match self.target.query(sql) { Ok(Some(mut rows)) => loop { match rows.step()? { StepResult::Row => {} - StepResult::IO => { - rows.run_once()?; - } - StepResult::Done => break, - StepResult::Interrupt => break, + StepResult::IO => rows.run_once()?, + StepResult::Done | StepResult::Interrupt => break, StepResult::Busy => { return Err(LimboError::InternalError("target database is busy".into())) } @@ -241,6 +272,17 @@ impl<'a> ApplyWriter<'a> { } } +impl<'a> Write for ApplyWriter<'a> { + fn write(&mut self, data: &[u8]) -> io::Result { + self.buf.extend_from_slice(data); + self.flush_complete_statements()?; + Ok(data.len()) + } + fn flush(&mut self) -> io::Result<()> { + self.flush_complete_statements() + } +} + pub trait ProgressSink { fn on(&mut self, _p: S) {} } @@ -256,20 +298,6 @@ impl ProgressSink for StderrProgress { } } -impl<'a> Write for ApplyWriter<'a> { - fn write(&mut self, data: &[u8]) -> io::Result { - // TODO: for now .dump only writes valid UTF-8 - self.buf.push_str( - std::str::from_utf8(data).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?, - ); - self.flush_complete_statements()?; - Ok(data.len()) - } - fn flush(&mut self) -> io::Result<()> { - self.flush_complete_statements() - } -} - pub const BEFORE_HELP_MSG: &str = r#" Turso SQL Shell Help