diff --git a/core/functions/datetime.rs b/core/functions/datetime.rs index a41dad656..3bc7e26d9 100644 --- a/core/functions/datetime.rs +++ b/core/functions/datetime.rs @@ -1,9 +1,11 @@ use crate::LimboError::InvalidModifier; use crate::Result; +use crate::{ends_with_ignore_ascii_case, eq_ignore_ascii_case, starts_with_ignore_ascii_case}; use crate::{types::Value, vdbe::Register}; use chrono::{ DateTime, Datelike, NaiveDate, NaiveDateTime, NaiveTime, TimeDelta, TimeZone, Timelike, Utc, }; +use turso_macros::match_ignore_ascii_case; /// Execution of date/time/datetime functions #[inline(always)] @@ -544,102 +546,123 @@ fn parse_modifier_time(s: &str) -> Result { } fn parse_modifier(modifier: &str) -> Result { - let modifier = modifier.trim().to_lowercase(); + let modifier = modifier.trim().as_bytes(); - match modifier.as_str() { + #[inline(always)] + fn from_bytes(bytes: &[u8]) -> &str { + unsafe { str::from_utf8_unchecked(bytes) } + } // safe because input is from &str + + match_ignore_ascii_case!(match modifier { // exact matches first - "ceiling" => Ok(Modifier::Ceiling), - "floor" => Ok(Modifier::Floor), - "start of month" => Ok(Modifier::StartOfMonth), - "start of year" => Ok(Modifier::StartOfYear), - "start of day" => Ok(Modifier::StartOfDay), - s if s.starts_with("weekday ") => { - let day = parse_modifier_number(&s[8..])?; - if !(0..=6).contains(&day) { - Err(InvalidModifier( - "Weekday must be between 0 and 6".to_string(), - )) - } else { - Ok(Modifier::Weekday(day as u32)) - } - } - "unixepoch" => Ok(Modifier::UnixEpoch), - "julianday" => Ok(Modifier::JulianDay), - "auto" => Ok(Modifier::Auto), - "localtime" => Ok(Modifier::Localtime), - "utc" => Ok(Modifier::Utc), - "subsec" | "subsecond" => Ok(Modifier::Subsec), - s if s.ends_with(" day") => Ok(Modifier::Days(parse_modifier_number(&s[..s.len() - 4])?)), - s if s.ends_with(" days") => Ok(Modifier::Days(parse_modifier_number(&s[..s.len() - 5])?)), - s if s.ends_with(" hour") => Ok(Modifier::Hours(parse_modifier_number(&s[..s.len() - 5])?)), - s if s.ends_with(" hours") => { - Ok(Modifier::Hours(parse_modifier_number(&s[..s.len() - 6])?)) - } - s if s.ends_with(" minute") => { - Ok(Modifier::Minutes(parse_modifier_number(&s[..s.len() - 7])?)) - } - s if s.ends_with(" minutes") => { - Ok(Modifier::Minutes(parse_modifier_number(&s[..s.len() - 8])?)) - } - s if s.ends_with(" second") => { - Ok(Modifier::Seconds(parse_modifier_number(&s[..s.len() - 7])?)) - } - s if s.ends_with(" seconds") => { - Ok(Modifier::Seconds(parse_modifier_number(&s[..s.len() - 8])?)) - } - s if s.ends_with(" month") => Ok(Modifier::Months( - parse_modifier_number(&s[..s.len() - 6])? as i32, - )), - s if s.ends_with(" months") => Ok(Modifier::Months( - parse_modifier_number(&s[..s.len() - 7])? as i32, - )), - s if s.ends_with(" year") => Ok(Modifier::Years( - parse_modifier_number(&s[..s.len() - 5])? as i32 - )), - s if s.ends_with(" years") => Ok(Modifier::Years( - parse_modifier_number(&s[..s.len() - 6])? as i32, - )), - s if s.starts_with('+') || s.starts_with('-') => { - let sign = if s.starts_with('-') { -1 } else { 1 }; - let parts: Vec<&str> = s[1..].split(' ').collect(); - let digits_in_date = 10; - match parts.len() { - 1 => { - if parts[0].len() == digits_in_date { - let date = parse_modifier_date(parts[0])?; - Ok(Modifier::DateOffset { - years: sign * date.year(), - months: sign * date.month() as i32, - days: sign * date.day() as i32, - }) + b"ceiling" => Ok(Modifier::Ceiling), + b"floor" => Ok(Modifier::Floor), + b"start of month" => Ok(Modifier::StartOfMonth), + b"start of year" => Ok(Modifier::StartOfYear), + b"start of day" => Ok(Modifier::StartOfDay), + b"unixepoch" => Ok(Modifier::UnixEpoch), + b"julianday" => Ok(Modifier::JulianDay), + b"auto" => Ok(Modifier::Auto), + b"localtime" => Ok(Modifier::Localtime), + b"utc" => Ok(Modifier::Utc), + b"subsec" | b"subsecond" => Ok(Modifier::Subsec), + _ => { + match modifier { + s if starts_with_ignore_ascii_case!(s, b"weekday ") => { + let day = parse_modifier_number(from_bytes(&s[8..]))?; + if !(0..=6).contains(&day) { + Err(InvalidModifier( + "Weekday must be between 0 and 6".to_string(), + )) } else { - // time values are either 12, 8 or 5 digits - let time = parse_modifier_time(parts[0])?; - let time_delta = sign * (time.num_seconds_from_midnight() as i32); - Ok(Modifier::TimeOffset(TimeDelta::seconds(time_delta.into()))) + Ok(Modifier::Weekday(day as u32)) } } - 2 => { - let date = parse_modifier_date(parts[0])?; - let time = parse_modifier_time(parts[1])?; - // Convert time to total seconds (with sign) - let time_delta = sign * (time.num_seconds_from_midnight() as i32); - Ok(Modifier::DateTimeOffset { - years: sign * (date.year()), - months: sign * (date.month() as i32), - days: sign * date.day() as i32, - seconds: time_delta, - }) + s if ends_with_ignore_ascii_case!(s, b" day") => Ok(Modifier::Days( + parse_modifier_number(from_bytes(&s[..s.len() - 4]))?, + )), + s if ends_with_ignore_ascii_case!(s, b" days") => Ok(Modifier::Days( + parse_modifier_number(from_bytes(&s[..s.len() - 5]))?, + )), + s if ends_with_ignore_ascii_case!(s, b" hour") => Ok(Modifier::Hours( + parse_modifier_number(from_bytes(&s[..s.len() - 5]))?, + )), + s if ends_with_ignore_ascii_case!(s, b" hours") => Ok(Modifier::Hours( + parse_modifier_number(from_bytes(&s[..s.len() - 6]))?, + )), + s if ends_with_ignore_ascii_case!(s, b" minute") => Ok(Modifier::Minutes( + parse_modifier_number(from_bytes(&s[..s.len() - 7]))?, + )), + s if ends_with_ignore_ascii_case!(s, b" minutes") => Ok(Modifier::Minutes( + parse_modifier_number(from_bytes(&s[..s.len() - 8]))?, + )), + s if ends_with_ignore_ascii_case!(s, b" second") => Ok(Modifier::Seconds( + parse_modifier_number(from_bytes(&s[..s.len() - 7]))?, + )), + s if ends_with_ignore_ascii_case!(s, b" seconds") => Ok(Modifier::Seconds( + parse_modifier_number(from_bytes(&s[..s.len() - 8]))?, + )), + s if ends_with_ignore_ascii_case!(s, b" month") => Ok(Modifier::Months( + parse_modifier_number(from_bytes(&s[..s.len() - 6]))? as i32, + )), + s if ends_with_ignore_ascii_case!(s, b" months") => Ok(Modifier::Months( + parse_modifier_number(from_bytes(&s[..s.len() - 7]))? as i32, + )), + s if ends_with_ignore_ascii_case!(s, b" year") => Ok(Modifier::Years( + parse_modifier_number(from_bytes(&s[..s.len() - 5]))? as i32, + )), + s if ends_with_ignore_ascii_case!(s, b" years") => Ok(Modifier::Years( + parse_modifier_number(from_bytes(&s[..s.len() - 6]))? as i32, + )), + s if starts_with_ignore_ascii_case!(s, b"+") + || starts_with_ignore_ascii_case!(s, b"-") => + { + let sign = if starts_with_ignore_ascii_case!(s, b"-") { + -1 + } else { + 1 + }; + let parts: Vec<&str> = from_bytes(&s[1..]).split(' ').collect(); + let digits_in_date = 10; + match parts.len() { + 1 => { + if parts[0].len() == digits_in_date { + let date = parse_modifier_date(parts[0])?; + Ok(Modifier::DateOffset { + years: sign * date.year(), + months: sign * date.month() as i32, + days: sign * date.day() as i32, + }) + } else { + // time values are either 12, 8 or 5 digits + let time = parse_modifier_time(parts[0])?; + let time_delta = sign * (time.num_seconds_from_midnight() as i32); + Ok(Modifier::TimeOffset(TimeDelta::seconds(time_delta.into()))) + } + } + 2 => { + let date = parse_modifier_date(parts[0])?; + let time = parse_modifier_time(parts[1])?; + // Convert time to total seconds (with sign) + let time_delta = sign * (time.num_seconds_from_midnight() as i32); + Ok(Modifier::DateTimeOffset { + years: sign * (date.year()), + months: sign * (date.month() as i32), + days: sign * date.day() as i32, + seconds: time_delta, + }) + } + _ => Err(InvalidModifier( + "Invalid date/time offset format".to_string(), + )), + } } _ => Err(InvalidModifier( "Invalid date/time offset format".to_string(), )), } } - _ => Err(InvalidModifier( - "Invalid date/time offset format".to_string(), - )), - } + }) } pub fn exec_timediff(values: &[Register]) -> Value { diff --git a/core/incremental/operator.rs b/core/incremental/operator.rs index 85f7e640c..4f148e943 100644 --- a/core/incremental/operator.rs +++ b/core/incremental/operator.rs @@ -10,6 +10,7 @@ use std::collections::{HashMap, HashSet}; use std::fmt::{self, Debug, Display}; use std::sync::Arc; use std::sync::Mutex; +use turso_macros::match_ignore_ascii_case; /// Tracks computation counts to verify incremental behavior (for tests now), and in the future /// should be used to provide statistics. @@ -936,8 +937,9 @@ impl ProjectOperator { } } Expr::FunctionCall { name, args, .. } => { - match name.as_str().to_lowercase().as_str() { - "hex" => { + let name_bytes = name.as_str().as_bytes(); + match_ignore_ascii_case!(match name_bytes { + b"hex" => { if args.len() == 1 { let arg_val = self.evaluate_expression(&args[0], values); match arg_val { @@ -949,7 +951,7 @@ impl ProjectOperator { } } _ => Value::Null, // Other functions not supported yet - } + }) } Expr::Parenthesized(inner) => { assert!( diff --git a/core/schema.rs b/core/schema.rs index a4f129645..b2fa66b31 100644 --- a/core/schema.rs +++ b/core/schema.rs @@ -20,7 +20,9 @@ use crate::result::LimboResult; use crate::storage::btree::BTreeCursor; use crate::translate::collate::CollationSeq; use crate::translate::plan::SelectPlan; -use crate::util::{module_args_from_sql, module_name_from_sql, IOExt, UnparsedFromSqlIndex}; +use crate::util::{ + module_args_from_sql, module_name_from_sql, type_from_name, IOExt, UnparsedFromSqlIndex, +}; use crate::{return_if_io, LimboError, MvCursor, Pager, RefValue, SymbolTable, VirtualTable}; use crate::{util::normalize_ident, Result}; use core::fmt; @@ -1103,25 +1105,7 @@ impl From for Column { let ty = match value.col_type { Some(ref data_type) => { // https://www.sqlite.org/datatype3.html - let type_name = data_type.name.clone().to_uppercase(); - - if type_name.contains("INT") { - Type::Integer - } else if type_name.contains("CHAR") - || type_name.contains("CLOB") - || type_name.contains("TEXT") - { - Type::Text - } else if type_name.contains("BLOB") || type_name.is_empty() { - Type::Blob - } else if type_name.contains("REAL") - || type_name.contains("FLOA") - || type_name.contains("DOUB") - { - Type::Real - } else { - Type::Numeric - } + type_from_name(&data_type.name) } None => Type::Null, }; diff --git a/core/storage/encryption.rs b/core/storage/encryption.rs index 51cf84ee5..f5715992e 100644 --- a/core/storage/encryption.rs +++ b/core/storage/encryption.rs @@ -6,6 +6,7 @@ use aes_gcm::{ Aes256Gcm, Key, Nonce, }; use std::ops::Deref; +use turso_macros::match_ignore_ascii_case; pub const ENCRYPTED_PAGE_SIZE: usize = 4096; @@ -138,13 +139,14 @@ impl TryFrom<&str> for CipherMode { type Error = LimboError; fn try_from(s: &str) -> Result { - match s.to_lowercase().as_str() { - "aes256gcm" | "aes-256-gcm" | "aes_256_gcm" => Ok(CipherMode::Aes256Gcm), - "aegis256" | "aegis-256" | "aegis_256" => Ok(CipherMode::Aegis256), + let s_bytes = s.as_bytes(); + match_ignore_ascii_case!(match s_bytes { + b"aes256gcm" | b"aes-256-gcm" | b"aes_256_gcm" => Ok(CipherMode::Aes256Gcm), + b"aegis256" | b"aegis-256" | b"aegis_256" => Ok(CipherMode::Aegis256), _ => Err(LimboError::InvalidArgument(format!( "Unknown cipher name: {s}" ))), - } + }) } } diff --git a/core/translate/logical.rs b/core/translate/logical.rs index df8bdd13a..9c970f800 100644 --- a/core/translate/logical.rs +++ b/core/translate/logical.rs @@ -14,6 +14,7 @@ use crate::{LimboError, Result}; use std::collections::HashMap; use std::fmt::{self, Display, Formatter}; use std::sync::Arc; +use turso_macros::match_ignore_ascii_case; use turso_parser::ast; /// Result type for preprocessing aggregate expressions @@ -1400,19 +1401,20 @@ impl<'a> LogicalPlanBuilder<'a> { /// Parse aggregate function name (considering argument count for min/max) fn parse_aggregate_function(name: &str, arg_count: usize) -> Option { - match name.to_uppercase().as_str() { - "COUNT" => Some(AggFunc::Count), - "SUM" => Some(AggFunc::Sum), - "AVG" => Some(AggFunc::Avg), + let name_bytes = name.as_bytes(); + match_ignore_ascii_case!(match name_bytes { + b"COUNT" => Some(AggFunc::Count), + b"SUM" => Some(AggFunc::Sum), + b"AVG" => Some(AggFunc::Avg), // MIN and MAX are only aggregates with 1 argument // With 2+ arguments, they're scalar functions - "MIN" if arg_count == 1 => Some(AggFunc::Min), - "MAX" if arg_count == 1 => Some(AggFunc::Max), - "GROUP_CONCAT" => Some(AggFunc::GroupConcat), - "STRING_AGG" => Some(AggFunc::StringAgg), - "TOTAL" => Some(AggFunc::Total), + b"MIN" if arg_count == 1 => Some(AggFunc::Min), + b"MAX" if arg_count == 1 => Some(AggFunc::Max), + b"GROUP_CONCAT" => Some(AggFunc::GroupConcat), + b"STRING_AGG" => Some(AggFunc::StringAgg), + b"TOTAL" => Some(AggFunc::Total), _ => None, - } + }) } // Check if expression contains aggregates diff --git a/core/translate/pragma.rs b/core/translate/pragma.rs index b04d0e87e..ff7c2ccf2 100644 --- a/core/translate/pragma.rs +++ b/core/translate/pragma.rs @@ -4,6 +4,7 @@ use chrono::Datelike; use std::rc::Rc; use std::sync::Arc; +use turso_macros::match_ignore_ascii_case; use turso_parser::ast::{self, ColumnDefinition, Expr, Literal, Name}; use turso_parser::ast::{PragmaName, QualifiedName}; @@ -213,17 +214,17 @@ fn update_pragma( PragmaName::AutoVacuum => { let auto_vacuum_mode = match value { Expr::Name(name) => { - let name = name.as_str().to_lowercase(); - match name.as_str() { - "none" => 0, - "full" => 1, - "incremental" => 2, + let name = name.as_str().as_bytes(); + match_ignore_ascii_case!(match name { + b"none" => 0, + b"full" => 1, + b"incremental" => 2, _ => { return Err(LimboError::InvalidArgument( "invalid auto vacuum mode".to_string(), )); } - } + }) } _ => { return Err(LimboError::InvalidArgument( @@ -329,11 +330,11 @@ fn update_pragma( let mode = match value { Expr::Name(name) => { - let name_upper = name.as_str().to_uppercase(); - match name_upper.as_str() { - "OFF" | "FALSE" | "NO" | "0" => SyncMode::Off, + let name_bytes = name.as_str().as_bytes(); + match_ignore_ascii_case!(match name_bytes { + b"OFF" | b"FALSE" | b"NO" | b"0" => SyncMode::Off, _ => SyncMode::Full, - } + }) } Expr::Literal(Literal::Numeric(n)) => match n.as_str() { "0" => SyncMode::Off, @@ -573,8 +574,11 @@ fn query_pragma( ast::Expr::Literal(Literal::Numeric(i)) => i.parse::().unwrap() != 0, ast::Expr::Literal(Literal::String(ref s)) | ast::Expr::Name(Name::Ident(ref s)) => { - let s = s.to_lowercase(); - s == "1" || s == "on" || s == "true" + let s = s.as_bytes(); + match_ignore_ascii_case!(match s { + b"1" | b"on" | b"true" => true, + _ => false, + }) } _ => { return Err(LimboError::ParseError(format!( diff --git a/core/util.rs b/core/util.rs index 7d22106a6..02d7a63d2 100644 --- a/core/util.rs +++ b/core/util.rs @@ -13,6 +13,7 @@ use std::{ sync::{Arc, Mutex}, }; use tracing::{instrument, Level}; +use turso_macros::match_ignore_ascii_case; use turso_parser::ast::{ self, fmt::ToTokens, Cmd, CreateTableBody, Expr, FunctionTail, Literal, Stmt, UnaryOperator, }; @@ -31,6 +32,58 @@ macro_rules! io_yield_many { }; } +#[macro_export] +macro_rules! eq_ignore_ascii_case { + ( $var:expr, $value:literal ) => {{ + match_ignore_ascii_case!(match $var { + $value => true, + _ => false, + }) + }}; +} + +#[macro_export] +macro_rules! contains_ignore_ascii_case { + ( $var:expr, $value:literal ) => {{ + let compare_to_idx = $var.len().saturating_sub($value.len()); + if $var.len() < $value.len() { + false + } else { + let mut result = false; + for i in 0..=compare_to_idx { + if eq_ignore_ascii_case!(&$var[i..i + $value.len()], $value) { + result = true; + break; + } + } + + result + } + }}; +} + +#[macro_export] +macro_rules! starts_with_ignore_ascii_case { + ( $var:expr, $value:literal ) => {{ + if $var.len() < $value.len() { + false + } else { + eq_ignore_ascii_case!(&$var[..$value.len()], $value) + } + }}; +} + +#[macro_export] +macro_rules! ends_with_ignore_ascii_case { + ( $var:expr, $value:literal ) => {{ + if $var.len() < $value.len() { + false + } else { + eq_ignore_ascii_case!(&$var[$var.len() - $value.len()..], $value) + } + }}; +} + pub trait IOExt { fn block(&self, f: impl FnMut() -> Result>) -> Result; } @@ -112,7 +165,10 @@ pub fn parse_schema_rows( "table" => { let root_page: i64 = row.get::(3)?; let sql: &str = row.get::<&str>(4)?; - if root_page == 0 && sql.to_lowercase().contains("create virtual") { + let sql_bytes = sql.as_bytes(); + if root_page == 0 + && contains_ignore_ascii_case!(sql_bytes, b"create virtual") + { let name: &str = row.get::<&str>(1)?; // a virtual table is found in the sqlite_schema, but it's no // longer in the in-memory schema. We need to recreate it if @@ -609,6 +665,27 @@ pub fn exprs_are_equivalent(expr1: &Expr, expr2: &Expr) -> bool { } } +pub(crate) fn type_from_name(type_name: &str) -> Type { + let type_name = type_name.as_bytes(); + if contains_ignore_ascii_case!(type_name, b"INT") { + Type::Integer + } else if contains_ignore_ascii_case!(type_name, b"CHAR") + || contains_ignore_ascii_case!(type_name, b"CLOB") + || contains_ignore_ascii_case!(type_name, b"TEXT") + { + Type::Text + } else if contains_ignore_ascii_case!(type_name, b"BLOB") || type_name.is_empty() { + Type::Blob + } else if contains_ignore_ascii_case!(type_name, b"REAL") + || contains_ignore_ascii_case!(type_name, b"FLOA") + || contains_ignore_ascii_case!(type_name, b"DOUB") + { + Type::Real + } else { + Type::Numeric + } +} + pub fn columns_from_create_table_body( body: &turso_parser::ast::CreateTableBody, ) -> crate::Result> { @@ -633,24 +710,7 @@ pub fn columns_from_create_table_body( ty: match col_type { Some(ref data_type) => { // https://www.sqlite.org/datatype3.html - let type_name = data_type.name.as_str().to_uppercase(); - if type_name.contains("INT") { - Type::Integer - } else if type_name.contains("CHAR") - || type_name.contains("CLOB") - || type_name.contains("TEXT") - { - Type::Text - } else if type_name.contains("BLOB") || type_name.is_empty() { - Type::Blob - } else if type_name.contains("REAL") - || type_name.contains("FLOA") - || type_name.contains("DOUB") - { - Type::Real - } else { - Type::Numeric - } + type_from_name(data_type.name.as_str()) } None => Type::Null, }, @@ -772,15 +832,16 @@ impl From<&str> for CacheMode { impl OpenMode { pub fn from_str(s: &str) -> Result { - match s.trim().to_lowercase().as_str() { - "ro" => Ok(OpenMode::ReadOnly), - "rw" => Ok(OpenMode::ReadWrite), - "memory" => Ok(OpenMode::Memory), - "rwc" => Ok(OpenMode::ReadWriteCreate), + let s_bytes = s.trim().as_bytes(); + match_ignore_ascii_case!(match s_bytes { + b"ro" => Ok(OpenMode::ReadOnly), + b"rw" => Ok(OpenMode::ReadWrite), + b"memory" => Ok(OpenMode::Memory), + b"rwc" => Ok(OpenMode::ReadWriteCreate), _ => Err(LimboError::InvalidArgument(format!( "Invalid mode: '{s}'. Expected one of 'ro', 'rw', 'memory', 'rwc'" ))), - } + }) } } diff --git a/core/vdbe/execute.rs b/core/vdbe/execute.rs index f26ecac82..36aabeedd 100644 --- a/core/vdbe/execute.rs +++ b/core/vdbe/execute.rs @@ -39,6 +39,7 @@ use std::{ rc::Rc, sync::{Arc, Mutex}, }; +use turso_macros::match_ignore_ascii_case; use crate::{pseudo::PseudoCursor, result::LimboResult}; @@ -8808,11 +8809,11 @@ pub fn op_journal_mode( // Currently, Turso only supports WAL mode // If a new mode is specified, we validate it but always return "wal" if let Some(mode) = new_mode { - let mode_lower = mode.to_lowercase(); + let mode_bytes = mode.as_bytes(); // Valid journal modes in SQLite are: delete, truncate, persist, memory, wal, off // We accept any valid mode but always use WAL - match mode_lower.as_str() { - "delete" | "truncate" | "persist" | "memory" | "wal" | "off" => { + match_ignore_ascii_case!(match mode_bytes { + b"delete" | b"truncate" | b"persist" | b"memory" | b"wal" | b"off" => { // Mode is valid, but we stay in WAL mode } _ => { @@ -8821,7 +8822,7 @@ pub fn op_journal_mode( "Unknown journal mode: {mode}" ))); } - } + }) } // Always return "wal" as the current journal mode diff --git a/macros/src/ext/match_ignore_ascii_case.rs b/macros/src/ext/match_ignore_ascii_case.rs index 85e62c0a2..ab06482e7 100644 --- a/macros/src/ext/match_ignore_ascii_case.rs +++ b/macros/src/ext/match_ignore_ascii_case.rs @@ -112,8 +112,14 @@ pub fn match_ignore_ascci_case(input: TokenStream) -> TokenStream { entry: &PathEntry, ) -> proc_macro2::TokenStream { let eof_handle = if let Some(ref result) = entry.result { + let guard = if let Some(ref b) = result.guard { + let expr = &b.1; + quote! { if #expr } + } else { + quote! {} + }; let body = &result.body; - quote! { None => { #body } } + quote! { None #guard => { #body } } } else { quote! {} };