Support non-utf8 blobs in .clone command

This commit is contained in:
PThorpe92
2025-08-14 21:30:17 -04:00
parent 2b289157d0
commit 9ccf79111a
2 changed files with 102 additions and 70 deletions

View File

@@ -1360,55 +1360,28 @@ impl Limbo {
}
}
}
// quoted select list and table name
// FIXME: sqlite has logic to check rowid and optionally preserve it, but it requires
// pragma index_list, and it seems to be relevant only for indexes.
let cols_str = cols
.iter()
.map(|c| quote_ident(c))
.collect::<Vec<_>>()
.join(", ");
let select = format!("SELECT {cols_str} FROM {}", quote_ident(table_name));
// FIXME: sqlite has logic to check rowid and optionally preserve it, but it requires
// pragma index_list, and it seems to be relevant only for indexes.
if let Some(mut rows) = conn.query(select)? {
loop {
match rows.step()? {
StepResult::Row => {
let row = rows.row().unwrap();
let mut vals = Vec::with_capacity(types.len());
for (i, t) in types.iter().enumerate() {
write!(out, "INSERT INTO {} VALUES(", quote_ident(table_name))?;
for i in 0..cols.len() {
if i > 0 {
out.write_all(b",")?;
}
let v = row.get::<&Value>(i)?;
let s =
if t.contains("CHAR") || t.contains("CLOB") || t.contains("TEXT") {
let mut s = String::new();
s.push('\'');
s.push_str(&v.to_string().replace('\'', "''"));
s.push('\'');
s
} else if t.contains("BLOB") {
match v {
Value::Blob(b) => {
let mut s = String::with_capacity(2 + b.len() * 2);
s.push_str("X'");
for byte in b {
use std::fmt::Write as _;
let _ = write!(&mut s, "{byte:02x}");
}
s.push('\'');
s
}
_ => "X''".to_string(),
}
} else {
v.to_string()
};
vals.push(s);
Self::write_sql_value_from_value(out, v)?;
}
writeln!(
out,
"INSERT INTO {} VALUES({});",
quote_ident(table_name),
vals.join(",")
)?;
out.write_all(b");\n")?;
}
StepResult::IO => rows.run_once()?,
StepResult::Done | StepResult::Interrupt => break,
@@ -1496,6 +1469,37 @@ impl Limbo {
Ok(())
}
fn write_sql_value_from_value<W: Write>(out: &mut W, v: &Value) -> io::Result<()> {
match v {
Value::Null => out.write_all(b"NULL"),
Value::Integer(i) => out.write_all(format!("{i}").as_bytes()),
Value::Float(f) => write!(out, "{f}").map(|_| ()),
Value::Text(s) => {
out.write_all(b"'")?;
let bytes = &s.value;
let mut i = 0;
while i < bytes.len() {
let b = bytes[i];
if b == b'\'' {
out.write_all(b"''")?;
} else {
out.write_all(&[b])?;
}
i += 1;
}
out.write_all(b"'")
}
Value::Blob(b) => {
out.write_all(b"X'")?;
const HEX: &[u8; 16] = b"0123456789abcdef";
for &byte in b {
out.write_all(&[HEX[(byte >> 4) as usize], HEX[(byte & 0x0F) as usize]])?;
}
out.write_all(b"'")
}
}
}
fn dump_database(&mut self) -> anyhow::Result<()> {
// Move writer out so we dont hold a field-borrow of self during the call.
let mut out = std::mem::take(&mut self.writer).unwrap();

View File

@@ -183,52 +183,83 @@ pub fn get_io(db_location: DbLocation, io_choice: &str) -> anyhow::Result<Arc<dy
pub struct ApplyWriter<'a> {
target: &'a Arc<turso_core::Connection>,
// accumulates until we see a statement terminator
buf: String,
// accumulate raw bytes to support non-utf8 BLOB types
buf: Vec<u8>,
}
impl<'a> ApplyWriter<'a> {
pub fn new(target: &'a Arc<turso_core::Connection>) -> Self {
Self {
target,
buf: String::new(),
buf: Vec::new(),
}
}
// Find the next statement terminator ;\n or ;\r\n in a byte buffer.
// Returns (end_idx_inclusive, drain_len), where drain_len includes the newline(s).
fn find_stmt_end(buf: &[u8]) -> Option<(usize, usize)> {
let mut i = 0;
while i < buf.len() {
// Look for ';'
if buf[i] == b';' {
// Accept ;\n
if i + 1 < buf.len() && buf[i + 1] == b'\n' {
return Some((i, 2));
}
// Accept ;\r\n
if i + 2 < buf.len() && buf[i + 1] == b'\r' && buf[i + 2] == b'\n' {
return Some((i, 3));
}
}
i += 1;
}
None
}
pub fn flush_complete_statements(&mut self) -> io::Result<()> {
// We emit statements with ";\n". Split conservatively on that
while let Some(idx) = self.buf.find(";\n") {
let stmt = self.buf[..idx + 1].to_string();
self.exec_stmt(&stmt)
.map_err(|e| io::Error::other(e.to_string()))?;
self.buf.drain(..idx + 2);
while let Some((end_inclusive, drain_len)) = Self::find_stmt_end(&self.buf) {
// Copy stmt bytes [0..=end_inclusive]
let stmt_bytes = self.buf[..=end_inclusive].to_vec();
// Drain including the trailing newline(s)
self.buf.drain(..end_inclusive + drain_len);
self.exec_stmt_bytes(&stmt_bytes)?;
}
Ok(())
}
// Handle final trailing statement that ends with ';' followed only by ASCII whitespace.
pub fn finish(mut self) -> io::Result<()> {
// Handle a trailing statement missing the final newline
if let Some(idx) = self.buf.rfind(';') {
if self.buf[idx..].starts_with(';') && self.buf[idx + 1..].trim().is_empty() {
let stmt = self.buf[..idx + 1].to_string();
self.exec_stmt(&stmt)
.map_err(|e| io::Error::other(e.to_string()))?;
// Skip if buffer empty or no ';'
if let Some(semicolon_pos) = self.buf.iter().rposition(|&b| b == b';') {
// Are all bytes after ';' ASCII whitespace?
if self.buf[semicolon_pos + 1..]
.iter()
.all(|&b| matches!(b, b' ' | b'\t' | b'\r' | b'\n'))
{
let stmt_bytes = self.buf[..=semicolon_pos].to_vec();
self.buf.clear();
self.exec_stmt_bytes(&stmt_bytes)?;
}
}
Ok(())
}
fn exec_stmt_bytes(&self, stmt_bytes: &[u8]) -> io::Result<()> {
// SQL must be UTF-8. If not, surface a clear error.
let sql = std::str::from_utf8(stmt_bytes).map_err(|e| {
io::Error::new(io::ErrorKind::InvalidData, format!("non-UTF8 SQL: {e}"))
})?;
self.exec_stmt(sql)
.map_err(|e| io::Error::other(e.to_string()))
}
fn exec_stmt(&self, sql: &str) -> Result<(), LimboError> {
match self.target.query(sql) {
Ok(Some(mut rows)) => loop {
match rows.step()? {
StepResult::Row => {}
StepResult::IO => {
rows.run_once()?;
}
StepResult::Done => break,
StepResult::Interrupt => break,
StepResult::IO => rows.run_once()?,
StepResult::Done | StepResult::Interrupt => break,
StepResult::Busy => {
return Err(LimboError::InternalError("target database is busy".into()))
}
@@ -241,6 +272,17 @@ impl<'a> ApplyWriter<'a> {
}
}
impl<'a> Write for ApplyWriter<'a> {
fn write(&mut self, data: &[u8]) -> io::Result<usize> {
self.buf.extend_from_slice(data);
self.flush_complete_statements()?;
Ok(data.len())
}
fn flush(&mut self) -> io::Result<()> {
self.flush_complete_statements()
}
}
pub trait ProgressSink {
fn on<S: Display>(&mut self, _p: S) {}
}
@@ -256,20 +298,6 @@ impl ProgressSink for StderrProgress {
}
}
impl<'a> Write for ApplyWriter<'a> {
fn write(&mut self, data: &[u8]) -> io::Result<usize> {
// TODO: for now .dump only writes valid UTF-8
self.buf.push_str(
std::str::from_utf8(data).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?,
);
self.flush_complete_statements()?;
Ok(data.len())
}
fn flush(&mut self) -> io::Result<()> {
self.flush_complete_statements()
}
}
pub const BEFORE_HELP_MSG: &str = r#"
Turso SQL Shell Help