From 8070e51e26331fe920f8538fd80f774941b3b0d0 Mon Sep 17 00:00:00 2001 From: PThorpe92 Date: Sun, 23 Feb 2025 12:53:36 -0500 Subject: [PATCH] Fix vdbe casting and rounding issues --- core/vdbe/insn.rs | 181 ++++++++++++---------------------------------- core/vdbe/mod.rs | 142 ++++++++++++++++++++++++------------ 2 files changed, 144 insertions(+), 179 deletions(-) diff --git a/core/vdbe/insn.rs b/core/vdbe/insn.rs index 079d09023..445eabbf9 100644 --- a/core/vdbe/insn.rs +++ b/core/vdbe/insn.rs @@ -5,6 +5,18 @@ use crate::storage::wal::CheckpointMode; use crate::types::{OwnedValue, Record}; use limbo_macros::Description; +macro_rules! final_agg_values { + ($var:ident) => { + if let OwnedValue::Agg(agg) = $var { + $var = agg.final_value(); + } + }; + ($lhs:ident, $rhs:ident) => { + final_agg_values!($lhs); + final_agg_values!($rhs); + }; +} + /// Flags provided to comparison instructions (e.g. Eq, Ne) which determine behavior related to NULL values. #[derive(Clone, Copy, Debug, Default)] pub struct CmpInsFlags(usize); @@ -688,42 +700,8 @@ pub enum Cookie { UserVersion = 6, } -fn cast_text_to_numerical(value: &str) -> OwnedValue { - if let Ok(x) = value.parse::() { - OwnedValue::Integer(x) - } else if let Ok(x) = value.parse::() { - OwnedValue::Float(x) - } else { - OwnedValue::Integer(0) - } -} - -fn cast_text_to_numerical(value: &str) -> OwnedValue { - if let Ok(x) = value.parse::() { - OwnedValue::Integer(x) - } else if let Ok(x) = value.parse::() { - OwnedValue::Float(x) - } else { - let idx = value - .chars() - .enumerate() - .find_map(|(i, c)| match i { - i if i == 0 && c == '-' => None, - i if i > 0 && !c.is_ascii_digit() => Some(i), - _ => None, - }) - .unwrap_or(0); - OwnedValue::Integer(value[0..idx].parse::().unwrap_or(0)) - } -} - pub fn exec_add(mut lhs: &OwnedValue, mut rhs: &OwnedValue) -> OwnedValue { - if let OwnedValue::Agg(agg) = lhs { - lhs = agg.final_value(); - } - if let OwnedValue::Agg(agg) = rhs { - rhs = agg.final_value(); - } + final_agg_values!(lhs, rhs); match (lhs, rhs) { (OwnedValue::Integer(lhs), OwnedValue::Integer(rhs)) => { let result = lhs.overflowing_add(*rhs); @@ -733,7 +711,9 @@ pub fn exec_add(mut lhs: &OwnedValue, mut rhs: &OwnedValue) -> OwnedValue { OwnedValue::Integer(result.0) } } - (OwnedValue::Float(lhs), OwnedValue::Float(rhs)) => OwnedValue::Float(lhs + rhs), + (OwnedValue::Float(lhs), OwnedValue::Float(rhs)) => { + OwnedValue::Float((lhs + rhs).round_to_precision(6)) + } (OwnedValue::Float(f), OwnedValue::Integer(i)) | (OwnedValue::Integer(i), OwnedValue::Float(f)) => OwnedValue::Float(*f + *i as f64), (OwnedValue::Null, _) | (_, OwnedValue::Null) => OwnedValue::Null, @@ -749,12 +729,7 @@ pub fn exec_add(mut lhs: &OwnedValue, mut rhs: &OwnedValue) -> OwnedValue { } pub fn exec_subtract(mut lhs: &OwnedValue, mut rhs: &OwnedValue) -> OwnedValue { - if let OwnedValue::Agg(agg) = lhs { - lhs = agg.final_value(); - } - if let OwnedValue::Agg(agg) = rhs { - rhs = agg.final_value(); - } + final_agg_values!(lhs, rhs); match (lhs, rhs) { (OwnedValue::Integer(lhs), OwnedValue::Integer(rhs)) => { let result = lhs.overflowing_sub(*rhs); @@ -782,12 +757,7 @@ pub fn exec_subtract(mut lhs: &OwnedValue, mut rhs: &OwnedValue) -> OwnedValue { } } pub fn exec_multiply(mut lhs: &OwnedValue, mut rhs: &OwnedValue) -> OwnedValue { - if let OwnedValue::Agg(agg) = lhs { - lhs = agg.final_value(); - } - if let OwnedValue::Agg(agg) = rhs { - rhs = agg.final_value(); - } + final_agg_values!(lhs, rhs); match (lhs, rhs) { (OwnedValue::Integer(lhs), OwnedValue::Integer(rhs)) => { let result = lhs.overflowing_mul(*rhs); @@ -797,7 +767,9 @@ pub fn exec_multiply(mut lhs: &OwnedValue, mut rhs: &OwnedValue) -> OwnedValue { OwnedValue::Integer(result.0) } } - (OwnedValue::Float(lhs), OwnedValue::Float(rhs)) => OwnedValue::Float(lhs * rhs), + (OwnedValue::Float(lhs), OwnedValue::Float(rhs)) => { + OwnedValue::Float((lhs * rhs).round_to_precision(6)) + } (OwnedValue::Integer(i), OwnedValue::Float(f)) | (OwnedValue::Float(f), OwnedValue::Integer(i)) => OwnedValue::Float(*i as f64 * { *f }), (OwnedValue::Null, _) | (_, OwnedValue::Null) => OwnedValue::Null, @@ -814,12 +786,7 @@ pub fn exec_multiply(mut lhs: &OwnedValue, mut rhs: &OwnedValue) -> OwnedValue { } pub fn exec_divide(mut lhs: &OwnedValue, mut rhs: &OwnedValue) -> OwnedValue { - if let OwnedValue::Agg(agg) = lhs { - lhs = agg.final_value(); - } - if let OwnedValue::Agg(agg) = rhs { - rhs = agg.final_value(); - } + final_agg_values!(lhs, rhs); match (lhs, rhs) { (_, OwnedValue::Integer(0)) | (_, OwnedValue::Float(0.0)) => OwnedValue::Null, (OwnedValue::Integer(lhs), OwnedValue::Integer(rhs)) => { @@ -845,12 +812,7 @@ pub fn exec_divide(mut lhs: &OwnedValue, mut rhs: &OwnedValue) -> OwnedValue { } pub fn exec_bit_and(mut lhs: &OwnedValue, mut rhs: &OwnedValue) -> OwnedValue { - if let OwnedValue::Agg(agg) = lhs { - lhs = agg.final_value(); - } - if let OwnedValue::Agg(agg) = rhs { - rhs = agg.final_value(); - } + final_agg_values!(lhs, rhs); match (lhs, rhs) { (OwnedValue::Null, _) | (_, OwnedValue::Null) => OwnedValue::Null, (_, OwnedValue::Integer(0)) @@ -875,12 +837,7 @@ pub fn exec_bit_and(mut lhs: &OwnedValue, mut rhs: &OwnedValue) -> OwnedValue { } pub fn exec_bit_or(mut lhs: &OwnedValue, mut rhs: &OwnedValue) -> OwnedValue { - if let OwnedValue::Agg(agg) = lhs { - lhs = agg.final_value(); - } - if let OwnedValue::Agg(agg) = rhs { - rhs = agg.final_value(); - } + final_agg_values!(lhs, rhs); match (lhs, rhs) { (OwnedValue::Null, _) | (_, OwnedValue::Null) => OwnedValue::Null, (OwnedValue::Integer(lh), OwnedValue::Integer(rh)) => OwnedValue::Integer(lh | rh), @@ -901,12 +858,7 @@ pub fn exec_bit_or(mut lhs: &OwnedValue, mut rhs: &OwnedValue) -> OwnedValue { } pub fn exec_remainder(mut lhs: &OwnedValue, mut rhs: &OwnedValue) -> OwnedValue { - if let OwnedValue::Agg(agg) = lhs { - lhs = agg.final_value(); - } - if let OwnedValue::Agg(agg) = rhs { - rhs = agg.final_value(); - } + final_agg_values!(lhs, rhs); match (lhs, rhs) { (OwnedValue::Null, _) | (_, OwnedValue::Null) @@ -947,9 +899,7 @@ pub fn exec_remainder(mut lhs: &OwnedValue, mut rhs: &OwnedValue) -> OwnedValue } pub fn exec_bit_not(mut reg: &OwnedValue) -> OwnedValue { - if let OwnedValue::Agg(agg) = reg { - reg = agg.final_value(); - } + final_agg_values!(reg); match reg { OwnedValue::Null => OwnedValue::Null, OwnedValue::Integer(i) => OwnedValue::Integer(!i), @@ -960,12 +910,7 @@ pub fn exec_bit_not(mut reg: &OwnedValue) -> OwnedValue { } pub fn exec_shift_left(mut lhs: &OwnedValue, mut rhs: &OwnedValue) -> OwnedValue { - if let OwnedValue::Agg(agg) = lhs { - lhs = agg.final_value(); - } - if let OwnedValue::Agg(agg) = rhs { - rhs = agg.final_value(); - } + final_agg_values!(lhs, rhs); match (lhs, rhs) { (OwnedValue::Null, _) | (_, OwnedValue::Null) => OwnedValue::Null, (OwnedValue::Integer(lh), OwnedValue::Integer(rh)) => { @@ -999,12 +944,7 @@ fn compute_shl(lhs: i64, rhs: i64) -> i64 { } pub fn exec_shift_right(mut lhs: &OwnedValue, mut rhs: &OwnedValue) -> OwnedValue { - if let OwnedValue::Agg(agg) = lhs { - lhs = agg.final_value(); - } - if let OwnedValue::Agg(agg) = rhs { - rhs = agg.final_value(); - } + final_agg_values!(lhs, rhs); match (lhs, rhs) { (OwnedValue::Null, _) | (_, OwnedValue::Null) => OwnedValue::Null, (OwnedValue::Integer(lh), OwnedValue::Integer(rh)) => { @@ -1051,9 +991,7 @@ fn compute_shr(lhs: i64, rhs: i64) -> i64 { } pub fn exec_boolean_not(mut reg: &OwnedValue) -> OwnedValue { - if let OwnedValue::Agg(agg) = reg { - reg = agg.final_value(); - } + final_agg_values!(reg); match reg { OwnedValue::Null => OwnedValue::Null, OwnedValue::Integer(i) => OwnedValue::Integer((*i == 0) as i64), @@ -1062,8 +1000,8 @@ pub fn exec_boolean_not(mut reg: &OwnedValue) -> OwnedValue { _ => todo!(), } } - -pub fn exec_concat(lhs: &OwnedValue, rhs: &OwnedValue) -> OwnedValue { +pub fn exec_concat(mut lhs: &OwnedValue, mut rhs: &OwnedValue) -> OwnedValue { + final_agg_values!(lhs, rhs); match (lhs, rhs) { (OwnedValue::Text(lhs_text), OwnedValue::Text(rhs_text)) => { OwnedValue::build_text(&(lhs_text.as_str().to_string() + rhs_text.as_str())) @@ -1074,10 +1012,6 @@ pub fn exec_concat(lhs: &OwnedValue, rhs: &OwnedValue) -> OwnedValue { (OwnedValue::Text(lhs_text), OwnedValue::Float(rhs_float)) => { OwnedValue::build_text(&(lhs_text.as_str().to_string() + &rhs_float.to_string())) } - (OwnedValue::Text(lhs_text), OwnedValue::Agg(rhs_agg)) => OwnedValue::build_text( - (lhs_text.as_str().to_string() + &rhs_agg.final_value().to_string()).as_str(), - ), - (OwnedValue::Integer(lhs_int), OwnedValue::Text(rhs_text)) => { OwnedValue::build_text(&(lhs_int.to_string() + rhs_text.as_str())) } @@ -1087,10 +1021,6 @@ pub fn exec_concat(lhs: &OwnedValue, rhs: &OwnedValue) -> OwnedValue { (OwnedValue::Integer(lhs_int), OwnedValue::Float(rhs_float)) => { OwnedValue::build_text(&(lhs_int.to_string() + &rhs_float.to_string())) } - (OwnedValue::Integer(lhs_int), OwnedValue::Agg(rhs_agg)) => { - OwnedValue::build_text(&(lhs_int.to_string() + &rhs_agg.final_value().to_string())) - } - (OwnedValue::Float(lhs_float), OwnedValue::Text(rhs_text)) => { OwnedValue::build_text(&(lhs_float.to_string() + rhs_text.as_str())) } @@ -1100,39 +1030,19 @@ pub fn exec_concat(lhs: &OwnedValue, rhs: &OwnedValue) -> OwnedValue { (OwnedValue::Float(lhs_float), OwnedValue::Float(rhs_float)) => { OwnedValue::build_text(&(lhs_float.to_string() + &rhs_float.to_string())) } - (OwnedValue::Float(lhs_float), OwnedValue::Agg(rhs_agg)) => { - OwnedValue::build_text(&(lhs_float.to_string() + &rhs_agg.final_value().to_string())) - } - - (OwnedValue::Agg(lhs_agg), OwnedValue::Text(rhs_text)) => { - OwnedValue::build_text(&(lhs_agg.final_value().to_string() + rhs_text.as_str())) - } - (OwnedValue::Agg(lhs_agg), OwnedValue::Integer(rhs_int)) => { - OwnedValue::build_text(&(lhs_agg.final_value().to_string() + &rhs_int.to_string())) - } - (OwnedValue::Agg(lhs_agg), OwnedValue::Float(rhs_float)) => { - OwnedValue::build_text(&(lhs_agg.final_value().to_string() + &rhs_float.to_string())) - } - (OwnedValue::Agg(lhs_agg), OwnedValue::Agg(rhs_agg)) => OwnedValue::build_text( - &(lhs_agg.final_value().to_string() + &rhs_agg.final_value().to_string()), - ), - (OwnedValue::Null, _) | (_, OwnedValue::Null) => OwnedValue::Null, (OwnedValue::Blob(_), _) | (_, OwnedValue::Blob(_)) => { todo!("TODO: Handle Blob conversion to String") } - (OwnedValue::Record(_), _) | (_, OwnedValue::Record(_)) => unreachable!(), + (OwnedValue::Record(_), _) + | (_, OwnedValue::Record(_)) + | (OwnedValue::Agg(_), _) + | (_, OwnedValue::Agg(_)) => unreachable!(), } } pub fn exec_and(mut lhs: &OwnedValue, mut rhs: &OwnedValue) -> OwnedValue { - if let OwnedValue::Agg(agg) = lhs { - lhs = agg.final_value(); - } - if let OwnedValue::Agg(agg) = rhs { - rhs = agg.final_value(); - } - + final_agg_values!(lhs, rhs); match (lhs, rhs) { (_, OwnedValue::Integer(0)) | (OwnedValue::Integer(0), _) @@ -1151,13 +1061,7 @@ pub fn exec_and(mut lhs: &OwnedValue, mut rhs: &OwnedValue) -> OwnedValue { } pub fn exec_or(mut lhs: &OwnedValue, mut rhs: &OwnedValue) -> OwnedValue { - if let OwnedValue::Agg(agg) = lhs { - lhs = agg.final_value(); - } - if let OwnedValue::Agg(agg) = rhs { - rhs = agg.final_value(); - } - + final_agg_values!(lhs, rhs); match (lhs, rhs) { (OwnedValue::Null, OwnedValue::Null) | (OwnedValue::Null, OwnedValue::Float(0.0)) @@ -1179,6 +1083,17 @@ pub fn exec_or(mut lhs: &OwnedValue, mut rhs: &OwnedValue) -> OwnedValue { } } +trait RoundToPrecision { + fn round_to_precision(self, precision: i32) -> f64; +} + +impl RoundToPrecision for f64 { + fn round_to_precision(self, precision: i32) -> f64 { + let factor = 10f64.powi(precision); + (self * factor).round() / factor + } +} + #[cfg(test)] mod tests { use crate::{ diff --git a/core/vdbe/mod.rs b/core/vdbe/mod.rs index b251caf72..5da0cd8e4 100644 --- a/core/vdbe/mod.rs +++ b/core/vdbe/mod.rs @@ -2592,10 +2592,13 @@ impl Program { ), }, OwnedValue::Text(text) => { - match checked_cast_text_to_numeric(&text.as_str()) { + match checked_cast_text_to_numeric(text.as_str()) { Ok(OwnedValue::Integer(i)) => { state.registers[*reg] = OwnedValue::Integer(i) } + Ok(OwnedValue::Float(f)) => { + state.registers[*reg] = OwnedValue::Integer(f as i64) + } _ => crate::bail_parse_error!( "MustBeInt: the value in register cannot be cast to integer" ), @@ -3656,24 +3659,63 @@ fn cast_text_to_integer(text: &str) -> OwnedValue { if let Ok(i) = text.parse::() { return OwnedValue::Integer(i); } - let idx = text - .chars() - .enumerate() - .find_map(|(i, c)| match i { - i if i == 0 && c == '-' => None, - i if i > 0 && !c.is_ascii_digit() => Some(i), - _ => None, - }) - .unwrap_or(0); - OwnedValue::Integer(text[0..idx].parse::().unwrap_or(0)) + let bytes = text.as_bytes(); + let mut end = 0; + if bytes[0] == b'-' { + end = 1; + } + while end < bytes.len() && bytes[end].is_ascii_digit() { + end += 1; + } + text[..end] + .parse::() + .map_or(OwnedValue::Integer(0), OwnedValue::Integer) } /// When casting a TEXT value to REAL, the longest possible prefix of the value that can be interpreted /// as a real number is extracted from the TEXT value and the remainder ignored. Any leading spaces in /// the TEXT value are ignored when converging from TEXT to REAL. /// If there is no prefix that can be interpreted as a real number, the result of the conversion is 0.0. -fn cast_text_to_real(text: &str) -> (OwnedValue, CastTextToRealResultCode) { - text_to_real(text) +fn cast_text_to_real(text: &str) -> OwnedValue { + let trimmed = text.trim_start(); + if trimmed.is_empty() { + return OwnedValue::Float(0.0); + } + if let Ok(num) = trimmed.parse::() { + return OwnedValue::Float(num); + } + let bytes = text.as_bytes(); + let mut end = 0; + let mut has_decimal = false; + let mut has_exponent = false; + + if bytes[0] == b'-' { + end = 1; + } + while end < bytes.len() { + match bytes[end] { + b'0'..=b'9' => end += 1, + b'.' if !has_decimal && !has_exponent => { + has_decimal = true; + end += 1; + } + b'e' | b'E' if !has_exponent => { + has_exponent = true; + end += 1; + // allow exponent sign + if end < bytes.len() && (bytes[end] == b'+' || bytes[end] == b'-') { + end += 1; + } + } + _ => break, + } + } + if end == 0 || (end == 1 && bytes[0] == b'-') { + return OwnedValue::Float(0.0); + } + text[..end] + .parse::() + .map_or(OwnedValue::Float(0.0), OwnedValue::Float) } /// NUMERIC Casting a TEXT or BLOB value into NUMERIC yields either an INTEGER or a REAL result. @@ -3686,44 +3728,52 @@ fn cast_text_to_real(text: &str) -> (OwnedValue, CastTextToRealResultCode) { /// IEEE 754 64-bit float and thus provides a 1-bit of margin for the text-to-float conversion operation.) /// Any text input that describes a value outside the range of a 64-bit signed integer yields a REAL result. /// Casting a REAL or INTEGER value to NUMERIC is a no-op, even if a real value could be losslessly converted to an integer. -fn checked_cast_text_to_numeric(text: &str) -> std::result::Result { - if !text.contains('.') && !text.contains('e') && !text.contains('E') { - // Looks like an integer - if let Ok(i) = text.parse::() { - return Ok(OwnedValue::Integer(i)); +pub fn checked_cast_text_to_numeric(text: &str) -> std::result::Result { + // sqlite will parse the first N digits of a string to numeric value, then determine + // whether _that_ value is more likely a real or integer value. e.g. + // '-100234-2344.23e14' evaluates to -100234 instead of -100234.0 + let bytes = text.as_bytes(); + let mut end = 0; + let mut has_decimal = false; + let mut has_exponent = false; + if bytes[0] == b'-' { + end = 1; + } + while end < bytes.len() { + match bytes[end] { + b'0'..=b'9' => end += 1, + b'.' if !has_decimal && !has_exponent => { + has_decimal = true; + end += 1; + } + b'e' | b'E' if !has_exponent => { + has_exponent = true; + end += 1; + // allow exponent sign + if end < bytes.len() && (bytes[end] == b'+' || bytes[end] == b'-') { + end += 1; + } + } + _ => break, } } - // Try as float - if let Ok(f) = text.parse::() { - return match cast_real_to_integer(f) { - Ok(i) => Ok(OwnedValue::Integer(i)), - Err(_) => Ok(OwnedValue::Float(f)), - }; + if end == 0 || (end == 1 && bytes[0] == b'-') { + return Err(()); + } + let text = &text[..end]; + if !has_decimal && !has_exponent { + Ok(text + .parse::() + .map_or(OwnedValue::Integer(0), OwnedValue::Integer)) + } else { + Ok(text + .parse::() + .map_or(OwnedValue::Float(0.0), OwnedValue::Float)) } - Err(()) } -/// Reference for function definition -/// https://github.com/sqlite/sqlite/blob/eb3a069fc82e53a40ea63076d66ab113a3b2b0c6/src/vdbe.c#L465 -fn cast_text_to_numeric(text: &str) -> OwnedValue { - let (real_cast, rc_real) = cast_text_to_real(text); - let (int_cast, rc_int) = cast_text_to_integer(text); - match (rc_real, rc_int) { - ( - CastTextToRealResultCode::NotValid, - CastTextToIntResultCode::ExcessSpace - | CastTextToIntResultCode::Success - | CastTextToIntResultCode::NotInt, - ) => int_cast, - ( - CastTextToRealResultCode::NotValid, - CastTextToIntResultCode::TooLargeOrMalformed | CastTextToIntResultCode::SpecialCase, - ) => real_cast, - (CastTextToRealResultCode::NotValidButPrefix, _) => real_cast, - (CastTextToRealResultCode::PureInt, CastTextToIntResultCode::Success) => int_cast, - (CastTextToRealResultCode::HasDecimal, _) => real_cast, - _ => real_cast, - } +fn cast_text_to_numeric(txt: &str) -> OwnedValue { + checked_cast_text_to_numeric(txt).unwrap_or(OwnedValue::Integer(0)) } // Check if float can be losslessly converted to 51-bit integer