diff --git a/core/types.rs b/core/types.rs index 5cbe08152..b1cf976cf 100644 --- a/core/types.rs +++ b/core/types.rs @@ -259,6 +259,13 @@ impl Value { _ => panic!("as_blob must be called only for Value::Blob"), } } + pub fn as_float(&self) -> f64 { + match self { + Value::Float(f) => *f, + Value::Integer(i) => *i as f64, + _ => panic!("as_float must be called only for Value::Float or Value::Integer"), + } + } pub fn from_text(text: &str) -> Self { Value::Text(Text::new(text)) @@ -495,10 +502,26 @@ impl Value { } } +#[derive(Debug, Clone, PartialEq)] +pub struct SumAggState { + pub r_err: f64, // Error term for Kahan-Babushka-Neumaier summation + pub approx: bool, // True if any non-integer value was input to the sum + pub ovrfl: bool, // Integer overflow seen +} +impl Default for SumAggState { + fn default() -> Self { + Self { + r_err: 0.0, + approx: false, + ovrfl: false, + } + } +} + #[derive(Debug, Clone, PartialEq)] pub enum AggContext { Avg(Value, Value), // acc and count - Sum(Value), + Sum(Value, SumAggState), Count(Value), Max(Option), Min(Option), @@ -522,7 +545,7 @@ impl AggContext { pub fn final_value(&self) -> &Value { match self { Self::Avg(acc, _count) => acc, - Self::Sum(acc) => acc, + Self::Sum(acc, _) => acc, Self::Count(count) => count, Self::Max(max) => max.as_ref().unwrap_or(&NULL), Self::Min(min) => min.as_ref().unwrap_or(&NULL), @@ -596,7 +619,7 @@ impl PartialOrd for AggContext { fn partial_cmp(&self, other: &AggContext) -> Option { match (self, other) { (Self::Avg(a, _), Self::Avg(b, _)) => a.partial_cmp(b), - (Self::Sum(a), Self::Sum(b)) => a.partial_cmp(b), + (Self::Sum(a, _), Self::Sum(b, _)) => a.partial_cmp(b), (Self::Count(a), Self::Count(b)) => a.partial_cmp(b), (Self::Max(a), Self::Max(b)) => a.partial_cmp(b), (Self::Min(a), Self::Min(b)) => a.partial_cmp(b), diff --git a/core/vdbe/execute.rs b/core/vdbe/execute.rs index c79042aff..5041ff478 100644 --- a/core/vdbe/execute.rs +++ b/core/vdbe/execute.rs @@ -45,7 +45,10 @@ use crate::{ 
use crate::{ storage::wal::CheckpointResult, - types::{AggContext, Cursor, ExternalAggState, IOResult, SeekKey, SeekOp, Value, ValueType}, + types::{ + AggContext, Cursor, ExternalAggState, IOResult, SeekKey, SeekOp, SumAggState, Value, + ValueType, + }, util::{ cast_real_to_integer, cast_text_to_integer, cast_text_to_numeric, cast_text_to_real, checked_cast_text_to_numeric, parse_schema_rows, RoundToPrecision, @@ -3044,6 +3047,33 @@ pub fn op_decr_jump_zero( Ok(InsnFunctionStepResult::Step) } +fn apply_kbn_step(acc: &mut Value, r: f64, state: &mut SumAggState) { + let s = acc.as_float(); + let t = s + r; + let correction = if s.abs() > r.abs() { + (s - t) + r + } else { + (r - t) + s + }; + state.r_err += correction; + *acc = Value::Float(t); +} + +// Add a (possibly large) integer to the running sum. +fn apply_kbn_step_int(acc: &mut Value, i: i64, state: &mut SumAggState) { + const THRESHOLD: i64 = 4503599627370496; // 2^52 + + if i <= -THRESHOLD || i >= THRESHOLD { + let i_sm = i % 16384; + let i_big = i - i_sm; + + apply_kbn_step(acc, i_big as f64, state); + apply_kbn_step(acc, i_sm as f64, state); + } else { + apply_kbn_step(acc, i as f64, state); + } +} + pub fn op_agg_step( program: &Program, state: &mut ProgramState, @@ -3065,12 +3095,14 @@ pub fn op_agg_step( AggFunc::Avg => { Register::Aggregate(AggContext::Avg(Value::Float(0.0), Value::Integer(0))) } - AggFunc::Sum => Register::Aggregate(AggContext::Sum(Value::Null)), + AggFunc::Sum => { + Register::Aggregate(AggContext::Sum(Value::Null, SumAggState::default())) + } AggFunc::Total => { // The result of total() is always a floating point value. // No overflow error is ever raised if any prior input was a floating point value. // Total() never throws an integer overflow. 
- Register::Aggregate(AggContext::Sum(Value::Float(0.0))) + Register::Aggregate(AggContext::Sum(Value::Float(0.0), SumAggState::default())) } AggFunc::Count | AggFunc::Count0 => { Register::Aggregate(AggContext::Count(Value::Integer(0))) @@ -3128,25 +3160,62 @@ pub fn op_agg_step( state.registers[*acc_reg], *acc_reg ); }; - let AggContext::Sum(acc) = agg.borrow_mut() else { + let AggContext::Sum(acc, sum_state) = agg.borrow_mut() else { unreachable!(); }; match col { Register::Value(owned_value) => { match owned_value { - Value::Integer(_) | Value::Float(_) => match acc { - Value::Null => *acc = owned_value.clone(), - _ => *acc += owned_value, - }, Value::Null => { - // Null values are ignored in sum + // Ignore NULLs } - _ => match acc { - Value::Null => *acc = Value::Float(0.0), - Value::Integer(i) => *acc = Value::Float(*i as f64 + 0.0), - Value::Float(_) => {} + + Value::Integer(i) => match acc { + Value::Null => { + *acc = Value::Integer(i); + } + Value::Integer(acc_i) => { + match acc_i.checked_add(i) { + Some(sum) => *acc = Value::Integer(sum), + None => { + // Overflow -> switch to float with KBN summation + let acc_f = *acc_i as f64; + *acc = Value::Float(acc_f); + sum_state.approx = true; + sum_state.ovrfl = true; + + apply_kbn_step_int(acc, i, sum_state); + } + } + } + Value::Float(_) => { + apply_kbn_step_int(acc, i, sum_state); + } _ => unreachable!(), }, + + Value::Float(f) => match acc { + Value::Null => { + *acc = Value::Float(f); + } + Value::Integer(i) => { + let i_f = *i as f64; + *acc = Value::Float(i_f); + sum_state.approx = true; + apply_kbn_step(acc, f, sum_state); + } + Value::Float(_) => { + sum_state.approx = true; + apply_kbn_step(acc, f, sum_state); + } + _ => unreachable!(), + }, + + _ => { + // If any input to sum() is neither an integer nor a NULL, then sum() returns a float + // https://sqlite.org/lang_aggfunc.html + sum_state.approx = true; + } } } _ => unreachable!(), @@ -3330,18 +3399,23 @@ pub fn op_agg_final( 
state.registers[*register] = Register::Value(acc.clone()); } AggFunc::Sum => { - let AggContext::Sum(acc) = agg.borrow_mut() else { + let AggContext::Sum(acc, sum_state) = agg.borrow_mut() else { unreachable!(); }; let value = match acc { - Value::Null => Value::Null, - v @ Value::Integer(_) | v @ Value::Float(_) => v.clone(), - _ => unreachable!(), + Value::Null => match sum_state.approx { + true => Value::Float(0.0), + false => Value::Null, + }, + Value::Integer(i) if !sum_state.approx && !sum_state.ovrfl => { + Value::Integer(*i) + } + _ => Value::Float(acc.as_float() + sum_state.r_err), }; state.registers[*register] = Register::Value(value); } AggFunc::Total => { - let AggContext::Sum(acc) = agg.borrow_mut() else { + let AggContext::Sum(acc, _) = agg.borrow_mut() else { unreachable!(); }; let value = match acc { diff --git a/tests/integration/fuzz/mod.rs b/tests/integration/fuzz/mod.rs index f2620177f..66b7d2058 100644 --- a/tests/integration/fuzz/mod.rs +++ b/tests/integration/fuzz/mod.rs @@ -7,7 +7,7 @@ mod tests { use rand::{Rng, SeedableRng}; use rand_chacha::ChaCha8Rng; - use rusqlite::params; + use rusqlite::{params, types::Value}; use crate::{ common::{limbo_exec_rows, rng_from_time, sqlite_exec_rows, TempDatabase}, @@ -1417,6 +1417,51 @@ mod tests { } } } + #[test] + // Simple fuzz test for SUM with floats + pub fn sum_agg_fuzz_floats() { + let _ = env_logger::try_init(); + + let (mut rng, seed) = rng_from_time(); + log::info!("seed: {seed}"); + + for _ in 0..100 { + let db = TempDatabase::new_empty(false); + let limbo_conn = db.connect_limbo(); + let sqlite_conn = rusqlite::Connection::open_in_memory().unwrap(); + + limbo_exec_rows(&db, &limbo_conn, "CREATE TABLE t(x)"); + sqlite_exec_rows(&sqlite_conn, "CREATE TABLE t(x)"); + + // Insert 50-100 random float values in [-100, 100) + let mut values = Vec::new(); + for _ in 0..rng.random_range(50..=100) { + let value = rng.random_range(-100.0..100.0).to_string(); + values.push(format!("({value})")); + } + 
let insert = format!("INSERT INTO t VALUES {}", values.join(",")); + limbo_exec_rows(&db, &limbo_conn, &insert); + sqlite_exec_rows(&sqlite_conn, &insert); + + let query = "SELECT sum(x) FROM t ORDER BY x"; + let limbo_result = limbo_exec_rows(&db, &limbo_conn, query); + let sqlite_result = sqlite_exec_rows(&sqlite_conn, query); + + let limbo_val = match limbo_result.first().and_then(|row| row.first()) { + Some(Value::Real(f)) => *f, + Some(Value::Null) | None => 0.0, + _ => panic!("Unexpected type in limbo result: {limbo_result:?}"), + }; + + let sqlite_val = match sqlite_result.first().and_then(|row| row.first()) { + Some(Value::Real(f)) => *f, + Some(Value::Null) | None => 0.0, + _ => panic!("Unexpected type in sqlite result: {sqlite_result:?}"), + }; + assert_eq!(limbo_val, sqlite_val, "seed: {seed}, values: {values:?}"); + } + } #[test] // Simple fuzz test for SUM with mixed numeric/non-numeric values (issue #2133)