Fix vdbe casting and rounding issues

This commit is contained in:
PThorpe92
2025-02-23 12:53:36 -05:00
parent 8f27a5fc92
commit 8070e51e26
2 changed files with 144 additions and 179 deletions

View File

@@ -5,6 +5,18 @@ use crate::storage::wal::CheckpointMode;
use crate::types::{OwnedValue, Record};
use limbo_macros::Description;
/// Replaces register bindings that hold an `OwnedValue::Agg` with the value
/// returned by that aggregate's `final_value()`.
///
/// Every argument must be the identifier of a mutable binding, because the
/// macro assigns to it. Bindings holding any other variant are left untouched.
/// Accepts either one identifier or two comma-separated identifiers.
macro_rules! final_agg_values {
    // Two registers: resolve each one independently, left first.
    ($lhs:ident, $rhs:ident) => {
        final_agg_values!($lhs);
        final_agg_values!($rhs);
    };
    // One register: rebind it to the aggregate's final value, if it is one.
    ($reg:ident) => {
        if let OwnedValue::Agg(state) = $reg {
            $reg = state.final_value();
        }
    };
}
/// Flags provided to comparison instructions (e.g. Eq, Ne) which determine behavior related to NULL values.
///
/// This is a newtype over a single `usize`; the derived `Default` wraps the value 0.
/// NOTE(review): the meaning of individual values/bits is defined outside this view —
/// presumably a bitmask, confirm against the accompanying `impl`.
#[derive(Clone, Copy, Debug, Default)]
pub struct CmpInsFlags(usize);
@@ -688,42 +700,8 @@ pub enum Cookie {
UserVersion = 6,
}
/// Converts `value` to a numeric `OwnedValue` by parsing the whole string.
///
/// Tries `i64` first, then `f64`; if neither parse accepts the entire string
/// the result is `OwnedValue::Integer(0)`.
fn cast_text_to_numerical(value: &str) -> OwnedValue {
    value
        .parse::<i64>()
        .map(OwnedValue::Integer)
        .or_else(|_| value.parse::<f64>().map(OwnedValue::Float))
        .unwrap_or(OwnedValue::Integer(0))
}
/// Converts `value` to a numeric `OwnedValue`.
///
/// The whole string is tried as `i64`, then as `f64`. If both fail, the
/// leading integer prefix (an optional `-` followed by ASCII digits) is parsed
/// instead; when there is no such prefix, or it does not fit in an `i64`, the
/// result is `OwnedValue::Integer(0)`.
fn cast_text_to_numerical(value: &str) -> OwnedValue {
    if let Ok(x) = value.parse::<i64>() {
        return OwnedValue::Integer(x);
    }
    if let Ok(x) = value.parse::<f64>() {
        return OwnedValue::Float(x);
    }

    // Measure the prefix in bytes, not chars: slicing a `str` takes byte
    // offsets, and a char-based index panics (or cuts at the wrong place) as
    // soon as the text contains a multibyte character. Only ASCII bytes are
    // ever counted, so the cut always lands on a char boundary.
    let bytes = value.as_bytes();
    let sign_len = usize::from(bytes.first() == Some(&b'-'));
    let digit_len = bytes[sign_len..]
        .iter()
        .take_while(|b| b.is_ascii_digit())
        .count();

    OwnedValue::Integer(value[..sign_len + digit_len].parse::<i64>().unwrap_or(0))
}
pub fn exec_add(mut lhs: &OwnedValue, mut rhs: &OwnedValue) -> OwnedValue {
if let OwnedValue::Agg(agg) = lhs {
lhs = agg.final_value();
}
if let OwnedValue::Agg(agg) = rhs {
rhs = agg.final_value();
}
final_agg_values!(lhs, rhs);
match (lhs, rhs) {
(OwnedValue::Integer(lhs), OwnedValue::Integer(rhs)) => {
let result = lhs.overflowing_add(*rhs);
@@ -733,7 +711,9 @@ pub fn exec_add(mut lhs: &OwnedValue, mut rhs: &OwnedValue) -> OwnedValue {
OwnedValue::Integer(result.0)
}
}
(OwnedValue::Float(lhs), OwnedValue::Float(rhs)) => OwnedValue::Float(lhs + rhs),
(OwnedValue::Float(lhs), OwnedValue::Float(rhs)) => {
OwnedValue::Float((lhs + rhs).round_to_precision(6))
}
(OwnedValue::Float(f), OwnedValue::Integer(i))
| (OwnedValue::Integer(i), OwnedValue::Float(f)) => OwnedValue::Float(*f + *i as f64),
(OwnedValue::Null, _) | (_, OwnedValue::Null) => OwnedValue::Null,
@@ -749,12 +729,7 @@ pub fn exec_add(mut lhs: &OwnedValue, mut rhs: &OwnedValue) -> OwnedValue {
}
pub fn exec_subtract(mut lhs: &OwnedValue, mut rhs: &OwnedValue) -> OwnedValue {
if let OwnedValue::Agg(agg) = lhs {
lhs = agg.final_value();
}
if let OwnedValue::Agg(agg) = rhs {
rhs = agg.final_value();
}
final_agg_values!(lhs, rhs);
match (lhs, rhs) {
(OwnedValue::Integer(lhs), OwnedValue::Integer(rhs)) => {
let result = lhs.overflowing_sub(*rhs);
@@ -782,12 +757,7 @@ pub fn exec_subtract(mut lhs: &OwnedValue, mut rhs: &OwnedValue) -> OwnedValue {
}
}
pub fn exec_multiply(mut lhs: &OwnedValue, mut rhs: &OwnedValue) -> OwnedValue {
if let OwnedValue::Agg(agg) = lhs {
lhs = agg.final_value();
}
if let OwnedValue::Agg(agg) = rhs {
rhs = agg.final_value();
}
final_agg_values!(lhs, rhs);
match (lhs, rhs) {
(OwnedValue::Integer(lhs), OwnedValue::Integer(rhs)) => {
let result = lhs.overflowing_mul(*rhs);
@@ -797,7 +767,9 @@ pub fn exec_multiply(mut lhs: &OwnedValue, mut rhs: &OwnedValue) -> OwnedValue {
OwnedValue::Integer(result.0)
}
}
(OwnedValue::Float(lhs), OwnedValue::Float(rhs)) => OwnedValue::Float(lhs * rhs),
(OwnedValue::Float(lhs), OwnedValue::Float(rhs)) => {
OwnedValue::Float((lhs * rhs).round_to_precision(6))
}
(OwnedValue::Integer(i), OwnedValue::Float(f))
| (OwnedValue::Float(f), OwnedValue::Integer(i)) => OwnedValue::Float(*i as f64 * { *f }),
(OwnedValue::Null, _) | (_, OwnedValue::Null) => OwnedValue::Null,
@@ -814,12 +786,7 @@ pub fn exec_multiply(mut lhs: &OwnedValue, mut rhs: &OwnedValue) -> OwnedValue {
}
pub fn exec_divide(mut lhs: &OwnedValue, mut rhs: &OwnedValue) -> OwnedValue {
if let OwnedValue::Agg(agg) = lhs {
lhs = agg.final_value();
}
if let OwnedValue::Agg(agg) = rhs {
rhs = agg.final_value();
}
final_agg_values!(lhs, rhs);
match (lhs, rhs) {
(_, OwnedValue::Integer(0)) | (_, OwnedValue::Float(0.0)) => OwnedValue::Null,
(OwnedValue::Integer(lhs), OwnedValue::Integer(rhs)) => {
@@ -845,12 +812,7 @@ pub fn exec_divide(mut lhs: &OwnedValue, mut rhs: &OwnedValue) -> OwnedValue {
}
pub fn exec_bit_and(mut lhs: &OwnedValue, mut rhs: &OwnedValue) -> OwnedValue {
if let OwnedValue::Agg(agg) = lhs {
lhs = agg.final_value();
}
if let OwnedValue::Agg(agg) = rhs {
rhs = agg.final_value();
}
final_agg_values!(lhs, rhs);
match (lhs, rhs) {
(OwnedValue::Null, _) | (_, OwnedValue::Null) => OwnedValue::Null,
(_, OwnedValue::Integer(0))
@@ -875,12 +837,7 @@ pub fn exec_bit_and(mut lhs: &OwnedValue, mut rhs: &OwnedValue) -> OwnedValue {
}
pub fn exec_bit_or(mut lhs: &OwnedValue, mut rhs: &OwnedValue) -> OwnedValue {
if let OwnedValue::Agg(agg) = lhs {
lhs = agg.final_value();
}
if let OwnedValue::Agg(agg) = rhs {
rhs = agg.final_value();
}
final_agg_values!(lhs, rhs);
match (lhs, rhs) {
(OwnedValue::Null, _) | (_, OwnedValue::Null) => OwnedValue::Null,
(OwnedValue::Integer(lh), OwnedValue::Integer(rh)) => OwnedValue::Integer(lh | rh),
@@ -901,12 +858,7 @@ pub fn exec_bit_or(mut lhs: &OwnedValue, mut rhs: &OwnedValue) -> OwnedValue {
}
pub fn exec_remainder(mut lhs: &OwnedValue, mut rhs: &OwnedValue) -> OwnedValue {
if let OwnedValue::Agg(agg) = lhs {
lhs = agg.final_value();
}
if let OwnedValue::Agg(agg) = rhs {
rhs = agg.final_value();
}
final_agg_values!(lhs, rhs);
match (lhs, rhs) {
(OwnedValue::Null, _)
| (_, OwnedValue::Null)
@@ -947,9 +899,7 @@ pub fn exec_remainder(mut lhs: &OwnedValue, mut rhs: &OwnedValue) -> OwnedValue
}
pub fn exec_bit_not(mut reg: &OwnedValue) -> OwnedValue {
if let OwnedValue::Agg(agg) = reg {
reg = agg.final_value();
}
final_agg_values!(reg);
match reg {
OwnedValue::Null => OwnedValue::Null,
OwnedValue::Integer(i) => OwnedValue::Integer(!i),
@@ -960,12 +910,7 @@ pub fn exec_bit_not(mut reg: &OwnedValue) -> OwnedValue {
}
pub fn exec_shift_left(mut lhs: &OwnedValue, mut rhs: &OwnedValue) -> OwnedValue {
if let OwnedValue::Agg(agg) = lhs {
lhs = agg.final_value();
}
if let OwnedValue::Agg(agg) = rhs {
rhs = agg.final_value();
}
final_agg_values!(lhs, rhs);
match (lhs, rhs) {
(OwnedValue::Null, _) | (_, OwnedValue::Null) => OwnedValue::Null,
(OwnedValue::Integer(lh), OwnedValue::Integer(rh)) => {
@@ -999,12 +944,7 @@ fn compute_shl(lhs: i64, rhs: i64) -> i64 {
}
pub fn exec_shift_right(mut lhs: &OwnedValue, mut rhs: &OwnedValue) -> OwnedValue {
if let OwnedValue::Agg(agg) = lhs {
lhs = agg.final_value();
}
if let OwnedValue::Agg(agg) = rhs {
rhs = agg.final_value();
}
final_agg_values!(lhs, rhs);
match (lhs, rhs) {
(OwnedValue::Null, _) | (_, OwnedValue::Null) => OwnedValue::Null,
(OwnedValue::Integer(lh), OwnedValue::Integer(rh)) => {
@@ -1051,9 +991,7 @@ fn compute_shr(lhs: i64, rhs: i64) -> i64 {
}
pub fn exec_boolean_not(mut reg: &OwnedValue) -> OwnedValue {
if let OwnedValue::Agg(agg) = reg {
reg = agg.final_value();
}
final_agg_values!(reg);
match reg {
OwnedValue::Null => OwnedValue::Null,
OwnedValue::Integer(i) => OwnedValue::Integer((*i == 0) as i64),
@@ -1062,8 +1000,8 @@ pub fn exec_boolean_not(mut reg: &OwnedValue) -> OwnedValue {
_ => todo!(),
}
}
pub fn exec_concat(lhs: &OwnedValue, rhs: &OwnedValue) -> OwnedValue {
pub fn exec_concat(mut lhs: &OwnedValue, mut rhs: &OwnedValue) -> OwnedValue {
final_agg_values!(lhs, rhs);
match (lhs, rhs) {
(OwnedValue::Text(lhs_text), OwnedValue::Text(rhs_text)) => {
OwnedValue::build_text(&(lhs_text.as_str().to_string() + rhs_text.as_str()))
@@ -1074,10 +1012,6 @@ pub fn exec_concat(lhs: &OwnedValue, rhs: &OwnedValue) -> OwnedValue {
(OwnedValue::Text(lhs_text), OwnedValue::Float(rhs_float)) => {
OwnedValue::build_text(&(lhs_text.as_str().to_string() + &rhs_float.to_string()))
}
(OwnedValue::Text(lhs_text), OwnedValue::Agg(rhs_agg)) => OwnedValue::build_text(
(lhs_text.as_str().to_string() + &rhs_agg.final_value().to_string()).as_str(),
),
(OwnedValue::Integer(lhs_int), OwnedValue::Text(rhs_text)) => {
OwnedValue::build_text(&(lhs_int.to_string() + rhs_text.as_str()))
}
@@ -1087,10 +1021,6 @@ pub fn exec_concat(lhs: &OwnedValue, rhs: &OwnedValue) -> OwnedValue {
(OwnedValue::Integer(lhs_int), OwnedValue::Float(rhs_float)) => {
OwnedValue::build_text(&(lhs_int.to_string() + &rhs_float.to_string()))
}
(OwnedValue::Integer(lhs_int), OwnedValue::Agg(rhs_agg)) => {
OwnedValue::build_text(&(lhs_int.to_string() + &rhs_agg.final_value().to_string()))
}
(OwnedValue::Float(lhs_float), OwnedValue::Text(rhs_text)) => {
OwnedValue::build_text(&(lhs_float.to_string() + rhs_text.as_str()))
}
@@ -1100,39 +1030,19 @@ pub fn exec_concat(lhs: &OwnedValue, rhs: &OwnedValue) -> OwnedValue {
(OwnedValue::Float(lhs_float), OwnedValue::Float(rhs_float)) => {
OwnedValue::build_text(&(lhs_float.to_string() + &rhs_float.to_string()))
}
(OwnedValue::Float(lhs_float), OwnedValue::Agg(rhs_agg)) => {
OwnedValue::build_text(&(lhs_float.to_string() + &rhs_agg.final_value().to_string()))
}
(OwnedValue::Agg(lhs_agg), OwnedValue::Text(rhs_text)) => {
OwnedValue::build_text(&(lhs_agg.final_value().to_string() + rhs_text.as_str()))
}
(OwnedValue::Agg(lhs_agg), OwnedValue::Integer(rhs_int)) => {
OwnedValue::build_text(&(lhs_agg.final_value().to_string() + &rhs_int.to_string()))
}
(OwnedValue::Agg(lhs_agg), OwnedValue::Float(rhs_float)) => {
OwnedValue::build_text(&(lhs_agg.final_value().to_string() + &rhs_float.to_string()))
}
(OwnedValue::Agg(lhs_agg), OwnedValue::Agg(rhs_agg)) => OwnedValue::build_text(
&(lhs_agg.final_value().to_string() + &rhs_agg.final_value().to_string()),
),
(OwnedValue::Null, _) | (_, OwnedValue::Null) => OwnedValue::Null,
(OwnedValue::Blob(_), _) | (_, OwnedValue::Blob(_)) => {
todo!("TODO: Handle Blob conversion to String")
}
(OwnedValue::Record(_), _) | (_, OwnedValue::Record(_)) => unreachable!(),
(OwnedValue::Record(_), _)
| (_, OwnedValue::Record(_))
| (OwnedValue::Agg(_), _)
| (_, OwnedValue::Agg(_)) => unreachable!(),
}
}
pub fn exec_and(mut lhs: &OwnedValue, mut rhs: &OwnedValue) -> OwnedValue {
if let OwnedValue::Agg(agg) = lhs {
lhs = agg.final_value();
}
if let OwnedValue::Agg(agg) = rhs {
rhs = agg.final_value();
}
final_agg_values!(lhs, rhs);
match (lhs, rhs) {
(_, OwnedValue::Integer(0))
| (OwnedValue::Integer(0), _)
@@ -1151,13 +1061,7 @@ pub fn exec_and(mut lhs: &OwnedValue, mut rhs: &OwnedValue) -> OwnedValue {
}
pub fn exec_or(mut lhs: &OwnedValue, mut rhs: &OwnedValue) -> OwnedValue {
if let OwnedValue::Agg(agg) = lhs {
lhs = agg.final_value();
}
if let OwnedValue::Agg(agg) = rhs {
rhs = agg.final_value();
}
final_agg_values!(lhs, rhs);
match (lhs, rhs) {
(OwnedValue::Null, OwnedValue::Null)
| (OwnedValue::Null, OwnedValue::Float(0.0))
@@ -1179,6 +1083,17 @@ pub fn exec_or(mut lhs: &OwnedValue, mut rhs: &OwnedValue) -> OwnedValue {
}
}
/// Rounding of a floating-point value to a fixed number of decimal places.
trait RoundToPrecision {
    /// Returns `self` rounded to `precision` digits after the decimal point.
    fn round_to_precision(self, precision: i32) -> f64;
}

impl RoundToPrecision for f64 {
    /// Scales by `10^precision`, rounds half away from zero, and scales back.
    ///
    /// If that computation does not produce a finite number — the input is NaN
    /// or infinite, the scaled value overflows, or `10^precision` itself
    /// overflows or underflows — `self` is returned unchanged instead of a
    /// spurious infinity or NaN.
    fn round_to_precision(self, precision: i32) -> f64 {
        let factor = 10f64.powi(precision);
        let rounded = (self * factor).round() / factor;
        if rounded.is_finite() {
            rounded
        } else {
            self
        }
    }
}
#[cfg(test)]
mod tests {
use crate::{

View File

@@ -2592,10 +2592,13 @@ impl Program {
),
},
OwnedValue::Text(text) => {
match checked_cast_text_to_numeric(&text.as_str()) {
match checked_cast_text_to_numeric(text.as_str()) {
Ok(OwnedValue::Integer(i)) => {
state.registers[*reg] = OwnedValue::Integer(i)
}
Ok(OwnedValue::Float(f)) => {
state.registers[*reg] = OwnedValue::Integer(f as i64)
}
_ => crate::bail_parse_error!(
"MustBeInt: the value in register cannot be cast to integer"
),
@@ -3656,24 +3659,63 @@ fn cast_text_to_integer(text: &str) -> OwnedValue {
if let Ok(i) = text.parse::<i64>() {
return OwnedValue::Integer(i);
}
let idx = text
.chars()
.enumerate()
.find_map(|(i, c)| match i {
i if i == 0 && c == '-' => None,
i if i > 0 && !c.is_ascii_digit() => Some(i),
_ => None,
})
.unwrap_or(0);
OwnedValue::Integer(text[0..idx].parse::<i64>().unwrap_or(0))
let bytes = text.as_bytes();
let mut end = 0;
if bytes[0] == b'-' {
end = 1;
}
while end < bytes.len() && bytes[end].is_ascii_digit() {
end += 1;
}
text[..end]
.parse::<i64>()
.map_or(OwnedValue::Integer(0), OwnedValue::Integer)
}
/// When casting a TEXT value to REAL, the longest possible prefix of the value that can be interpreted
/// as a real number is extracted from the TEXT value and the remainder ignored. Any leading spaces in
/// the TEXT value are ignored when converting from TEXT to REAL.
/// If there is no prefix that can be interpreted as a real number, the result of the conversion is 0.0.
fn cast_text_to_real(text: &str) -> (OwnedValue, CastTextToRealResultCode) {
text_to_real(text)
/// Casts TEXT to REAL using the longest leading prefix that reads as a real number.
///
/// Leading whitespace is skipped. The accepted prefix is an optional `+`/`-`
/// sign, digits with at most one `.`, and an optional exponent (`e`/`E`, an
/// optional sign, and at least one digit). An exponent marker that is not
/// followed by digits is not part of the number, so `"1e"` casts to `1.0`.
/// If no prefix contains at least one mantissa digit the result is `0.0`;
/// this includes words such as `"inf"` and `"nan"`.
fn cast_text_to_real(text: &str) -> OwnedValue {
    // Scan the trimmed text: every offset below indexes into `trimmed`.
    let trimmed = text.trim_start();
    let bytes = trimmed.as_bytes();
    let digits_from = |start: usize| {
        bytes[start..]
            .iter()
            .take_while(|b| b.is_ascii_digit())
            .count()
    };

    // Optional sign, then the mantissa.
    let mut end = usize::from(matches!(bytes.first(), Some(b'+') | Some(b'-')));
    let mut mantissa_digits = digits_from(end);
    end += mantissa_digits;
    if bytes.get(end) == Some(&b'.') {
        let fraction_digits = digits_from(end + 1);
        mantissa_digits += fraction_digits;
        end += 1 + fraction_digits;
    }
    if mantissa_digits == 0 {
        return OwnedValue::Float(0.0);
    }

    // The exponent is consumed only when it carries at least one digit;
    // otherwise the prefix ends right after the mantissa.
    if matches!(bytes.get(end), Some(b'e') | Some(b'E')) {
        let mut exp_start = end + 1;
        if matches!(bytes.get(exp_start), Some(b'+') | Some(b'-')) {
            exp_start += 1;
        }
        let exp_digits = digits_from(exp_start);
        if exp_digits > 0 {
            end = exp_start + exp_digits;
        }
    }

    // Only ASCII bytes were consumed, so `end` is a valid char boundary.
    trimmed[..end]
        .parse::<f64>()
        .map_or(OwnedValue::Float(0.0), OwnedValue::Float)
}
/// NUMERIC Casting a TEXT or BLOB value into NUMERIC yields either an INTEGER or a REAL result.
@@ -3686,44 +3728,52 @@ fn cast_text_to_real(text: &str) -> (OwnedValue, CastTextToRealResultCode) {
/// IEEE 754 64-bit float and thus provides a 1-bit of margin for the text-to-float conversion operation.)
/// Any text input that describes a value outside the range of a 64-bit signed integer yields a REAL result.
/// Casting a REAL or INTEGER value to NUMERIC is a no-op, even if a real value could be losslessly converted to an integer.
fn checked_cast_text_to_numeric(text: &str) -> std::result::Result<OwnedValue, ()> {
if !text.contains('.') && !text.contains('e') && !text.contains('E') {
// Looks like an integer
if let Ok(i) = text.parse::<i64>() {
return Ok(OwnedValue::Integer(i));
pub fn checked_cast_text_to_numeric(text: &str) -> std::result::Result<OwnedValue, ()> {
// sqlite will parse the first N digits of a string to numeric value, then determine
// whether _that_ value is more likely a real or integer value. e.g.
// '-100234-2344.23e14' evaluates to -100234 instead of -100234.0
let bytes = text.as_bytes();
let mut end = 0;
let mut has_decimal = false;
let mut has_exponent = false;
if bytes[0] == b'-' {
end = 1;
}
while end < bytes.len() {
match bytes[end] {
b'0'..=b'9' => end += 1,
b'.' if !has_decimal && !has_exponent => {
has_decimal = true;
end += 1;
}
b'e' | b'E' if !has_exponent => {
has_exponent = true;
end += 1;
// allow exponent sign
if end < bytes.len() && (bytes[end] == b'+' || bytes[end] == b'-') {
end += 1;
}
}
_ => break,
}
}
// Try as float
if let Ok(f) = text.parse::<f64>() {
return match cast_real_to_integer(f) {
Ok(i) => Ok(OwnedValue::Integer(i)),
Err(_) => Ok(OwnedValue::Float(f)),
};
if end == 0 || (end == 1 && bytes[0] == b'-') {
return Err(());
}
let text = &text[..end];
if !has_decimal && !has_exponent {
Ok(text
.parse::<i64>()
.map_or(OwnedValue::Integer(0), OwnedValue::Integer))
} else {
Ok(text
.parse::<f64>()
.map_or(OwnedValue::Float(0.0), OwnedValue::Float))
}
Err(())
}
/// Reference for function definition
/// https://github.com/sqlite/sqlite/blob/eb3a069fc82e53a40ea63076d66ab113a3b2b0c6/src/vdbe.c#L465
fn cast_text_to_numeric(text: &str) -> OwnedValue {
let (real_cast, rc_real) = cast_text_to_real(text);
let (int_cast, rc_int) = cast_text_to_integer(text);
match (rc_real, rc_int) {
(
CastTextToRealResultCode::NotValid,
CastTextToIntResultCode::ExcessSpace
| CastTextToIntResultCode::Success
| CastTextToIntResultCode::NotInt,
) => int_cast,
(
CastTextToRealResultCode::NotValid,
CastTextToIntResultCode::TooLargeOrMalformed | CastTextToIntResultCode::SpecialCase,
) => real_cast,
(CastTextToRealResultCode::NotValidButPrefix, _) => real_cast,
(CastTextToRealResultCode::PureInt, CastTextToIntResultCode::Success) => int_cast,
(CastTextToRealResultCode::HasDecimal, _) => real_cast,
_ => real_cast,
}
/// Casts TEXT to NUMERIC without ever failing.
///
/// Returns whatever `checked_cast_text_to_numeric` produces for `txt`, or
/// `OwnedValue::Integer(0)` when that function reports an error.
fn cast_text_to_numeric(txt: &str) -> OwnedValue {
    match checked_cast_text_to_numeric(txt) {
        Ok(numeric) => numeric,
        Err(()) => OwnedValue::Integer(0),
    }
}
// Check if float can be losslessly converted to 51-bit integer