From 6d55cdba3b67eecfb12c843565779845968e1bd2 Mon Sep 17 00:00:00 2001 From: PThorpe92 Date: Mon, 24 Feb 2025 12:30:38 -0500 Subject: [PATCH] Remove allocations from numeric text casting, cleanups --- core/util.rs | 451 ++-------------------------------------------- core/vdbe/insn.rs | 16 +- core/vdbe/mod.rs | 140 +++++++------- 3 files changed, 88 insertions(+), 519 deletions(-) diff --git a/core/util.rs b/core/util.rs index 7bc73a42f..aa91836fd 100644 --- a/core/util.rs +++ b/core/util.rs @@ -1,13 +1,22 @@ -use core::num::IntErrorKind; use limbo_sqlite3_parser::ast::{self, CreateTableBody, Expr, FunctionTail, Literal}; use std::{rc::Rc, sync::Arc}; use crate::{ schema::{self, Column, Schema, Type}, - types::OwnedValue, LimboError, OpenFlags, Result, Statement, StepResult, IO, }; +pub trait RoundToPrecision { + fn round_to_precision(self, precision: f64) -> f64; +} + +impl RoundToPrecision for f64 { + fn round_to_precision(self, precision: f64) -> f64 { + let factor = 10f64.powf(precision); + (self * factor).round() / factor + } +} + // https://sqlite.org/lang_keywords.html const QUOTE_PAIRS: &[(char, char)] = &[('"', '"'), ('[', ']'), ('`', '`')]; @@ -604,184 +613,6 @@ pub fn decode_percent(uri: &str) -> String { String::from_utf8_lossy(&decoded).to_string() } -#[derive(Debug, PartialEq)] -/// Reference: -/// https://github.com/sqlite/sqlite/blob/master/src/util.c#L798 -pub enum CastTextToIntResultCode { - NotInt = -1, - Success = 0, - ExcessSpace = 1, - TooLargeOrMalformed = 2, - #[allow(dead_code)] - SpecialCase = 3, -} - -pub fn text_to_integer(text: &str) -> (OwnedValue, CastTextToIntResultCode) { - let text = text.trim(); - if text.is_empty() { - return (OwnedValue::Integer(0), CastTextToIntResultCode::NotInt); - } - let mut accum = String::new(); - let mut sign = false; - let mut has_digit = false; - let mut excess_space = false; - - let chars = text.chars(); - - for c in chars { - match c { - '0'..='9' => { - has_digit = true; - accum.push(c); - } - '+' | '-' if !has_digit && !sign => { - sign = true; - accum.push(c); - } - _ => { - excess_space = true; - break; - } - } - } - - match accum.parse::() { - Ok(num) => { - if excess_space { - return ( - OwnedValue::Integer(num), - CastTextToIntResultCode::ExcessSpace, - ); - } - - return (OwnedValue::Integer(num), CastTextToIntResultCode::Success); - } - Err(e) => match e.kind() { - IntErrorKind::NegOverflow | IntErrorKind::PosOverflow => ( - OwnedValue::Integer(0), - CastTextToIntResultCode::TooLargeOrMalformed, - ), - _ => (OwnedValue::Integer(0), CastTextToIntResultCode::NotInt), - }, - } -} - -#[derive(Debug, PartialEq)] -/// Reference -/// https://github.com/sqlite/sqlite/blob/master/src/util.c#L529 -pub enum CastTextToRealResultCode { - PureInt = 1, - HasDecimal = 2, - NotValid = 0, - NotValidButPrefix = -1, -} - -pub fn text_to_real(text: &str) -> (OwnedValue, CastTextToRealResultCode) { - let text = text.trim(); - if text.is_empty() { - return (OwnedValue::Float(0.0), CastTextToRealResultCode::NotValid); - } - let mut accum = String::new(); - let mut has_decimal_separator = false; - let mut sign = false; - let mut exp_sign = false; - let mut has_exponent = false; - let mut has_digit = false; - let mut has_decimal_digit = false; - let mut excess_space = false; - - let mut chars = text.chars(); - - 'outer: while let Some(c) = chars.next() { - match c { - '0'..='9' if !has_decimal_separator => { - has_digit = true; - accum.push(c); - } - '0'..='9' => { - // This pattern is used for both decimal and exponent digits - has_decimal_digit = true; - accum.push(c); - } - '+' | '-' if !has_digit && !sign => { - sign = true; - accum.push(c); - } - '.' if !has_decimal_separator => { - // Check if next char is a number - if let Some(ch) = chars.next() { - match ch { - '0'..='9' => { - has_decimal_separator = true; - accum.push(c); - accum.push(ch); - } - _ => { - excess_space = true; - break; - } - } - } else { - excess_space = true; - } - } - 'E' | 'e' if !has_exponent && (!has_decimal_separator || has_decimal_digit) => { - // Lookahead if next char is a number or sign - let mut curr_sign = None; - loop { - if let Some(ch) = chars.next() { - match ch { - '0'..='9' => { - has_exponent = true; - accum.push(c); - if let Some(sign) = curr_sign { - exp_sign = true; - accum.push(sign); - } - accum.push(ch); - break; - } - '+' | '-' => { - curr_sign = Some(ch); - } - _ => { - excess_space = true; - break 'outer; - } - } - } else { - excess_space = true; - break 'outer; - } - } - } - _ => { - excess_space = true; - break; - } - } - } - - if let Ok(num) = accum.parse::() { - if !has_decimal_separator && !exp_sign && !has_exponent && !excess_space { - return (OwnedValue::Float(num), CastTextToRealResultCode::PureInt); - } - - if excess_space { - // TODO see if this branch satisfies: not a valid number, but has a valid prefix which - // includes a decimal point and/or an eNNN clause - return ( - OwnedValue::Float(num), - CastTextToRealResultCode::NotValidButPrefix, - ); - } - - return (OwnedValue::Float(num), CastTextToRealResultCode::HasDecimal); - } - - return (OwnedValue::Float(0.0), CastTextToRealResultCode::NotValid); -} - #[cfg(test)] pub mod tests { use super::*; @@ -1332,264 +1163,4 @@ pub mod tests { "/home/user/db.sqlite" ); } - - #[test] - fn test_text_to_integer() { - assert_eq!( - text_to_integer("1"), - (OwnedValue::Integer(1), CastTextToIntResultCode::Success), - ); - assert_eq!( - text_to_integer("-1"), - (OwnedValue::Integer(-1), CastTextToIntResultCode::Success), - ); - assert_eq!( - text_to_integer("10000000"), - ( - OwnedValue::Integer(10000000), - CastTextToIntResultCode::Success, - ), - ); - assert_eq!( - text_to_integer("-10000000"), - ( - OwnedValue::Integer(-10000000), - CastTextToIntResultCode::Success, - ), - ); - assert_eq!( - text_to_integer("xxx"), - (OwnedValue::Integer(0), CastTextToIntResultCode::NotInt), - ); - assert_eq!( - text_to_integer("123xxx"), - ( - OwnedValue::Integer(123), - CastTextToIntResultCode::ExcessSpace, - ), - ); - assert_eq!( - text_to_integer("9223372036854775807"), - ( - OwnedValue::Integer(i64::MAX), - CastTextToIntResultCode::Success, - ), - ); - assert_eq!( - text_to_integer("9223372036854775808"), - ( - OwnedValue::Integer(0), - CastTextToIntResultCode::TooLargeOrMalformed, - ), - ); - assert_eq!( - text_to_integer("-9223372036854775808"), - ( - OwnedValue::Integer(i64::MIN), - CastTextToIntResultCode::Success, - ), - ); - assert_eq!( - text_to_integer("-9223372036854775809"), - ( - OwnedValue::Integer(0), - CastTextToIntResultCode::TooLargeOrMalformed, - ), - ); - assert_eq!( - text_to_integer("-"), - (OwnedValue::Integer(0), CastTextToIntResultCode::NotInt,), - ); - } - - #[test] - fn test_text_to_real() { - assert_eq!( - text_to_real("1"), - (OwnedValue::Float(1.0), CastTextToRealResultCode::PureInt), - ); - assert_eq!( - text_to_real("-1"), - (OwnedValue::Float(-1.0), CastTextToRealResultCode::PureInt), - ); - assert_eq!( - text_to_real("1.0"), - (OwnedValue::Float(1.0), CastTextToRealResultCode::HasDecimal), - ); - assert_eq!( - text_to_real("-1.0"), - ( - OwnedValue::Float(-1.0), - CastTextToRealResultCode::HasDecimal, - ), - ); - assert_eq!( - text_to_real("1e10"), - ( - OwnedValue::Float(1e10), - CastTextToRealResultCode::HasDecimal, - ), - ); - assert_eq!( - text_to_real("-1e10"), - ( - OwnedValue::Float(-1e10), - CastTextToRealResultCode::HasDecimal, - ), - ); - assert_eq!( - text_to_real("1e-10"), - ( - OwnedValue::Float(1e-10), - CastTextToRealResultCode::HasDecimal, - ), - ); - assert_eq!( - text_to_real("-1e-10"), - ( - OwnedValue::Float(-1e-10), - CastTextToRealResultCode::HasDecimal, - ), - ); - assert_eq!( - text_to_real("1.123e10"), - ( - OwnedValue::Float(1.123e10), - CastTextToRealResultCode::HasDecimal, - ), - ); - assert_eq!( - text_to_real("-1.123e10"), - ( - OwnedValue::Float(-1.123e10), - CastTextToRealResultCode::HasDecimal, - ), - ); - assert_eq!( - text_to_real("1.123e-10"), - ( - OwnedValue::Float(1.123e-10), - CastTextToRealResultCode::HasDecimal, - ), - ); - assert_eq!( - text_to_real("-1.123e-10"), - ( - OwnedValue::Float(-1.123e-10), - CastTextToRealResultCode::HasDecimal, - ), - ); - assert_eq!( - text_to_real("1-282584294928"), - ( - OwnedValue::Float(1.0), - CastTextToRealResultCode::NotValidButPrefix - ), - ); - assert_eq!( - text_to_real("xxx"), - (OwnedValue::Float(0.0), CastTextToRealResultCode::NotValid), - ); - assert_eq!( - text_to_real("1.7976931348623157e308"), - ( - OwnedValue::Float(f64::MAX), - CastTextToRealResultCode::HasDecimal, - ), - ); - assert_eq!( - text_to_real("1.7976931348623157e309"), - ( - OwnedValue::Float(f64::INFINITY), - CastTextToRealResultCode::HasDecimal, - ), - ); - assert_eq!( - text_to_real("-1.7976931348623157e308"), - ( - OwnedValue::Float(f64::MIN), - CastTextToRealResultCode::HasDecimal, - ), - ); - assert_eq!( - text_to_real("-1.7976931348623157e309"), - ( - OwnedValue::Float(f64::NEG_INFINITY), - CastTextToRealResultCode::HasDecimal, - ), - ); - assert_eq!( - text_to_real("1E"), - ( - OwnedValue::Float(1.0), - CastTextToRealResultCode::NotValidButPrefix, - ), - ); - assert_eq!( - text_to_real("1EE"), - ( - OwnedValue::Float(1.0), - CastTextToRealResultCode::NotValidButPrefix, - ), - ); - assert_eq!( - text_to_real("-1E"), - ( - OwnedValue::Float(-1.0), - CastTextToRealResultCode::NotValidButPrefix, - ), - ); - assert_eq!( - text_to_real("1."), - ( - OwnedValue::Float(1.0), - CastTextToRealResultCode::NotValidButPrefix, - ), - ); - assert_eq!( - text_to_real("-1."), - ( - OwnedValue::Float(-1.0), - CastTextToRealResultCode::NotValidButPrefix, - ), - ); - assert_eq!( - text_to_real("1.23E"), - ( - OwnedValue::Float(1.23), - CastTextToRealResultCode::NotValidButPrefix, - ), - ); - assert_eq!( - text_to_real("1.23E-"), - ( - OwnedValue::Float(1.23), - CastTextToRealResultCode::NotValidButPrefix, - ), - ); - assert_eq!( - text_to_real("0"), - (OwnedValue::Float(0.0), CastTextToRealResultCode::PureInt,), - ); - assert_eq!( - text_to_real("-0"), - (OwnedValue::Float(-0.0), CastTextToRealResultCode::PureInt,), - ); - assert_eq!( - text_to_real("-0"), - (OwnedValue::Float(0.0), CastTextToRealResultCode::PureInt,), - ); - assert_eq!( - text_to_real("-0.0"), - (OwnedValue::Float(0.0), CastTextToRealResultCode::HasDecimal,), - ); - assert_eq!( - text_to_real("0.0"), - (OwnedValue::Float(0.0), CastTextToRealResultCode::HasDecimal,), - ); - assert_eq!( - text_to_real("-"), - (OwnedValue::Float(0.0), CastTextToRealResultCode::NotValid,), - ); - } } diff --git a/core/vdbe/insn.rs b/core/vdbe/insn.rs index 445eabbf9..c9893fed7 100644 --- a/core/vdbe/insn.rs +++ b/core/vdbe/insn.rs @@ -3,6 +3,7 @@ use std::num::NonZero; use super::{cast_text_to_numeric, AggFunc, BranchOffset, CursorID, FuncCtx, PageIdx}; use crate::storage::wal::CheckpointMode; use crate::types::{OwnedValue, Record}; +use crate::util::RoundToPrecision; use limbo_macros::Description; macro_rules! final_agg_values { @@ -712,7 +713,7 @@ pub fn exec_add(mut lhs: &OwnedValue, mut rhs: &OwnedValue) -> OwnedValue { } } (OwnedValue::Float(lhs), OwnedValue::Float(rhs)) => { - OwnedValue::Float((lhs + rhs).round_to_precision(6)) + OwnedValue::Float((lhs + rhs).round_to_precision(6.0)) } (OwnedValue::Float(f), OwnedValue::Integer(i)) | (OwnedValue::Integer(i), OwnedValue::Float(f)) => OwnedValue::Float(*f + *i as f64), @@ -768,7 +769,7 @@ pub fn exec_multiply(mut lhs: &OwnedValue, mut rhs: &OwnedValue) -> OwnedValue { } } (OwnedValue::Float(lhs), OwnedValue::Float(rhs)) => { - OwnedValue::Float((lhs * rhs).round_to_precision(6)) + OwnedValue::Float((lhs * rhs).round_to_precision(6.0)) } (OwnedValue::Integer(i), OwnedValue::Float(f)) | (OwnedValue::Float(f), OwnedValue::Integer(i)) => OwnedValue::Float(*i as f64 * { *f }), @@ -1083,17 +1084,6 @@ pub fn exec_or(mut lhs: &OwnedValue, mut rhs: &OwnedValue) -> OwnedValue { } } -trait RoundToPrecision { - fn round_to_precision(self, precision: i32) -> f64; -} - -impl RoundToPrecision for f64 { - fn round_to_precision(self, precision: i32) -> f64 { - let factor = 10f64.powi(precision); - (self * factor).round() / factor - } -} - #[cfg(test)] mod tests { use crate::{ diff --git a/core/vdbe/mod.rs b/core/vdbe/mod.rs index 00ff7f258..a3339d048 100644 --- a/core/vdbe/mod.rs +++ b/core/vdbe/mod.rs @@ -39,12 +39,10 @@ use crate::storage::wal::CheckpointResult; use crate::storage::{btree::BTreeCursor, pager::Pager}; use crate::translate::plan::{ResultSetColumn, TableReference}; use crate::types::{ - AggContext, Cursor, CursorResult, ExternalAggState, OwnedValue, Record, SeekKey, SeekOp, -}; -use crate::util::{ - parse_schema_rows, text_to_integer, text_to_real, CastTextToIntResultCode, - CastTextToRealResultCode, + AggContext, Cursor, CursorResult, ExternalAggState, OwnedValue, OwnedValueType, Record, + SeekKey, SeekOp, }; +use crate::util::{parse_schema_rows, RoundToPrecision}; use crate::vdbe::builder::CursorType; use crate::vdbe::insn::Insn; use crate::vector::{vector32, vector64, vector_distance_cos, vector_extract}; @@ -1179,20 +1177,18 @@ impl Program { } else { conn.auto_commit.replace(*auto_commit); } + } else if !*auto_commit { + return Err(LimboError::TxError( + "cannot start a transaction within a transaction".to_string(), + )); + } else if *rollback { + return Err(LimboError::TxError( + "cannot rollback - no transaction is active".to_string(), + )); } else { - if !*auto_commit { - return Err(LimboError::TxError( - "cannot start a transaction within a transaction".to_string(), - )); - } else if *rollback { - return Err(LimboError::TxError( - "cannot rollback - no transaction is active".to_string(), - )); - } else { - return Err(LimboError::TxError( - "cannot commit - no transaction is active".to_string(), - )); - } + return Err(LimboError::TxError( + "cannot commit - no transaction is active".to_string(), + )); } return self.halt(pager); } @@ -2040,7 +2036,7 @@ impl Program { unreachable!("Cast with non-text type"); }; let result = - exec_cast(®_value_argument, ®_value_type.as_str()); + exec_cast(®_value_argument, reg_value_type.as_str()); state.registers[*dest] = result; } ScalarFunc::Changes => { @@ -2078,8 +2074,8 @@ impl Program { }; OwnedValue::Integer(exec_glob( cache, - &pattern.as_str(), - &text.as_str(), + pattern.as_str(), + text.as_str(), ) as i64) } @@ -2110,12 +2106,12 @@ impl Program { let match_expression = &state.registers[*start_reg + 1]; let pattern = match pattern { - OwnedValue::Text(_) => pattern.clone(), - _ => exec_cast(pattern, "TEXT"), + OwnedValue::Text(_) => pattern, + _ => &exec_cast(pattern, "TEXT"), }; let match_expression = match match_expression { - OwnedValue::Text(_) => match_expression.clone(), - _ => exec_cast(match_expression, "TEXT"), + OwnedValue::Text(_) => match_expression, + _ => &exec_cast(match_expression, "TEXT"), }; let result = match (pattern, match_expression) { @@ -2131,8 +2127,8 @@ impl Program { }; OwnedValue::Integer(exec_like_with_escape( - &pattern.as_str(), - &match_expression.as_str(), + pattern.as_str(), + match_expression.as_str(), escape, ) as i64) @@ -2148,14 +2144,14 @@ impl Program { }; OwnedValue::Integer(exec_like( cache, - &pattern.as_str(), - &match_expression.as_str(), + pattern.as_str(), + match_expression.as_str(), ) as i64) } - (OwnedValue::Null, OwnedValue::Null) - | (OwnedValue::Null, _) - | (_, OwnedValue::Null) => OwnedValue::Null, + (OwnedValue::Null, _) | (_, OwnedValue::Null) => { + OwnedValue::Null + } _ => { unreachable!("Like failed"); } @@ -2825,7 +2821,7 @@ impl Program { .expect("only weak ref to connection?"); let auto_commit = *connection.auto_commit.borrow(); tracing::trace!("Halt auto_commit {}", auto_commit); - return if auto_commit { + if auto_commit { let current_state = connection.transaction_state.borrow().clone(); if current_state == TransactionState::Read { pager.end_read_tx()?; @@ -2849,8 +2845,8 @@ impl Program { conn.set_changes(self.n_change.get()); } } - return Ok(StepResult::Done); - }; + Ok(StepResult::Done) + } } } @@ -3451,30 +3447,35 @@ fn exec_unicode(reg: &OwnedValue) -> OwnedValue { fn _to_float(reg: &OwnedValue) -> f64 { match reg { - OwnedValue::Text(x) => x.as_str().parse().unwrap_or(0.0), + OwnedValue::Text(x) => match cast_text_to_numeric(x.as_str()) { + OwnedValue::Integer(i) => i as f64, + OwnedValue::Float(f) => f, + _ => unreachable!(), + }, OwnedValue::Integer(x) => *x as f64, OwnedValue::Float(x) => *x, + OwnedValue::Agg(ctx) => _to_float(ctx.final_value()), _ => 0.0, } } fn exec_round(reg: &OwnedValue, precision: Option) -> OwnedValue { - let precision = match precision { - Some(OwnedValue::Text(x)) => x.as_str().parse().unwrap_or(0.0), - Some(OwnedValue::Integer(x)) => x as f64, - Some(OwnedValue::Float(x)) => x, - Some(OwnedValue::Null) => return OwnedValue::Null, - _ => 0.0, + let reg = _to_float(reg); + let round = |reg: f64, f: f64| { + let precision = if f < 1.0 { 0.0 } else { f }; + OwnedValue::Float(reg.round_to_precision(precision)) }; - - let reg = match reg { - OwnedValue::Agg(ctx) => _to_float(ctx.final_value()), - _ => _to_float(reg), - }; - - let precision = if precision < 1.0 { 0.0 } else { precision }; - let multiplier = 10f64.powi(precision as i32); - OwnedValue::Float(((reg * multiplier).round()) / multiplier) + match precision { + Some(OwnedValue::Text(x)) => match cast_text_to_numeric(x.as_str()) { + OwnedValue::Integer(i) => round(reg, i as f64), + OwnedValue::Float(f) => round(reg, f), + _ => unreachable!(), + }, + Some(OwnedValue::Integer(i)) => round(reg, i as f64), + Some(OwnedValue::Float(f)) => round(reg, f), + None => round(reg, 0.0), + _ => OwnedValue::Null, + } } // Implements TRIM pattern matching. @@ -3566,9 +3567,9 @@ fn exec_cast(value: &OwnedValue, datatype: &str) -> OwnedValue { OwnedValue::Blob(b) => { // Convert BLOB to TEXT first let text = String::from_utf8_lossy(b); - cast_text_to_real(&text).0 + cast_text_to_real(&text) } - OwnedValue::Text(t) => cast_text_to_real(t.as_str()).0, + OwnedValue::Text(t) => cast_text_to_real(t.as_str()), OwnedValue::Integer(i) => OwnedValue::Float(*i as f64), OwnedValue::Float(f) => OwnedValue::Float(*f), _ => OwnedValue::Float(0.0), @@ -3577,9 +3578,9 @@ fn exec_cast(value: &OwnedValue, datatype: &str) -> OwnedValue { OwnedValue::Blob(b) => { // Convert BLOB to TEXT first let text = String::from_utf8_lossy(b); - cast_text_to_integer(&text).0 + cast_text_to_integer(&text) } - OwnedValue::Text(t) => cast_text_to_integer(t.as_str()).0, + OwnedValue::Text(t) => cast_text_to_integer(t.as_str()), OwnedValue::Integer(i) => OwnedValue::Integer(*i), // A cast of a REAL value into an INTEGER results in the integer between the REAL value and zero // that is closest to the REAL value. If a REAL is greater than the greatest possible signed integer (+9223372036854775807) @@ -3677,14 +3678,14 @@ fn cast_text_to_integer(text: &str) -> OwnedValue { /// the TEXT value are ignored when converging from TEXT to REAL. /// If there is no prefix that can be interpreted as a real number, the result of the conversion is 0.0. fn cast_text_to_real(text: &str) -> OwnedValue { - let trimmed = text.trim_start(); + let trimmed = text.trim(); if trimmed.is_empty() { return OwnedValue::Float(0.0); } if let Ok(num) = trimmed.parse::() { return OwnedValue::Float(num); } - let Ok((_, _, text)) = parse_numeric_str(trimmed) else { + let Ok((_, text)) = parse_numeric_str(trimmed) else { return OwnedValue::Float(0.0); }; text.parse::() @@ -3705,19 +3706,19 @@ pub fn checked_cast_text_to_numeric(text: &str) -> std::result::Result Ok(text .parse::() - .map_or(OwnedValue::Integer(0), OwnedValue::Integer)) - } else { - Ok(text + .map_or(OwnedValue::Integer(0), OwnedValue::Integer)), + OwnedValueType::Float => Ok(text .parse::() - .map_or(OwnedValue::Float(0.0), OwnedValue::Float)) + .map_or(OwnedValue::Float(0.0), OwnedValue::Float)), + _ => unreachable!(), } } -fn parse_numeric_str(text: &str) -> Result<(bool, bool, &str), ()> { +fn parse_numeric_str(text: &str) -> Result<(OwnedValueType, &str), ()> { let bytes = text.trim_start().as_bytes(); let mut end = 0; let mut has_decimal = false; @@ -3746,7 +3747,14 @@ fn parse_numeric_str(text: &str) -> Result<(bool, bool, &str), ()> { if end == 0 || (end == 1 && bytes[0] == b'-') { return Err(()); } - Ok((has_decimal, has_exponent, &text[..end])) + Ok(( + if !has_decimal && !has_exponent { + OwnedValueType::Integer + } else { + OwnedValueType::Float + }, + &text[..end], + )) } fn cast_text_to_numeric(txt: &str) -> OwnedValue {