Merge 'Fix sum() to follow the SQLite semantics' from FamHaggs

### Follow SUM [spec](https://sqlite.org/lang_aggfunc.html)
This PR updates the `SUM` aggregation logic to follow the
[Kahan–Babushka–Neumaier summation
algorithm](https://en.wikipedia.org/wiki/Kahan_summation_algorithm),
consistent with SQLite’s implementation. It improves the numerical
stability of floating-point summation.This fixes issue #2252 . I added a
fuzz test to ensure the compatibility of the implementations
I also fixed the return types for `SUM` to match SQLite’s documented
behavior. This was previously discussed in
[#2182](https://github.com/tursodatabase/turso/pull/2182), but part of
the logic was later unintentionally overwritten by
[#2265](https://github.com/tursodatabase/turso/pull/2265).
I introduced two helper functions, `apply_kbn_step` and
`apply_kbn_step_int`, in `vbde/execute.rs` to handle floating-point and
integer accumulation respectively. However, I’m new to this codebase and
would welcome constructive feedback on whether there’s a better place
for these helpers.

Reviewed-by: Preston Thorpe (@PThorpe92)

Closes #2270
This commit is contained in:
Pekka Enberg
2025-07-27 09:08:34 +03:00
3 changed files with 164 additions and 22 deletions

View File

@@ -259,6 +259,13 @@ impl Value {
_ => panic!("as_blob must be called only for Value::Blob"),
}
}
pub fn as_float(&self) -> f64 {
match self {
Value::Float(f) => *f,
Value::Integer(i) => *i as f64,
_ => panic!("as_float must be called only for Value::Float or Value::Integer"),
}
}
pub fn from_text(text: &str) -> Self {
Value::Text(Text::new(text))
@@ -495,10 +502,26 @@ impl Value {
}
}
#[derive(Debug, Clone, PartialEq)]
pub struct SumAggState {
pub r_err: f64, // Error term for Kahan-Babushka-Neumaier summation
pub approx: bool, // True if any non-integer value was input to the sum
pub ovrfl: bool, // Integer overflow seen
}
impl Default for SumAggState {
fn default() -> Self {
Self {
r_err: 0.0,
approx: false,
ovrfl: false,
}
}
}
#[derive(Debug, Clone, PartialEq)]
pub enum AggContext {
Avg(Value, Value), // acc and count
Sum(Value),
Sum(Value, SumAggState),
Count(Value),
Max(Option<Value>),
Min(Option<Value>),
@@ -522,7 +545,7 @@ impl AggContext {
pub fn final_value(&self) -> &Value {
match self {
Self::Avg(acc, _count) => acc,
Self::Sum(acc) => acc,
Self::Sum(acc, _) => acc,
Self::Count(count) => count,
Self::Max(max) => max.as_ref().unwrap_or(&NULL),
Self::Min(min) => min.as_ref().unwrap_or(&NULL),
@@ -596,7 +619,7 @@ impl PartialOrd<AggContext> for AggContext {
fn partial_cmp(&self, other: &AggContext) -> Option<std::cmp::Ordering> {
match (self, other) {
(Self::Avg(a, _), Self::Avg(b, _)) => a.partial_cmp(b),
(Self::Sum(a), Self::Sum(b)) => a.partial_cmp(b),
(Self::Sum(a, _), Self::Sum(b, _)) => a.partial_cmp(b),
(Self::Count(a), Self::Count(b)) => a.partial_cmp(b),
(Self::Max(a), Self::Max(b)) => a.partial_cmp(b),
(Self::Min(a), Self::Min(b)) => a.partial_cmp(b),

View File

@@ -45,7 +45,10 @@ use crate::{
use crate::{
storage::wal::CheckpointResult,
types::{AggContext, Cursor, ExternalAggState, IOResult, SeekKey, SeekOp, Value, ValueType},
types::{
AggContext, Cursor, ExternalAggState, IOResult, SeekKey, SeekOp, SumAggState, Value,
ValueType,
},
util::{
cast_real_to_integer, cast_text_to_integer, cast_text_to_numeric, cast_text_to_real,
checked_cast_text_to_numeric, parse_schema_rows, RoundToPrecision,
@@ -3044,6 +3047,33 @@ pub fn op_decr_jump_zero(
Ok(InsnFunctionStepResult::Step)
}
fn apply_kbn_step(acc: &mut Value, r: f64, state: &mut SumAggState) {
let s = acc.as_float();
let t = s + r;
let correction = if s.abs() > r.abs() {
(s - t) + r
} else {
(r - t) + s
};
state.r_err += correction;
*acc = Value::Float(t);
}
// Add a (possibly large) integer to the running sum.
fn apply_kbn_step_int(acc: &mut Value, i: i64, state: &mut SumAggState) {
const THRESHOLD: i64 = 4503599627370496; // 2^52
if i <= -THRESHOLD || i >= THRESHOLD {
let i_sm = i % 16384;
let i_big = i - i_sm;
apply_kbn_step(acc, i_big as f64, state);
apply_kbn_step(acc, i_sm as f64, state);
} else {
apply_kbn_step(acc, i as f64, state);
}
}
pub fn op_agg_step(
program: &Program,
state: &mut ProgramState,
@@ -3065,12 +3095,14 @@ pub fn op_agg_step(
AggFunc::Avg => {
Register::Aggregate(AggContext::Avg(Value::Float(0.0), Value::Integer(0)))
}
AggFunc::Sum => Register::Aggregate(AggContext::Sum(Value::Null)),
AggFunc::Sum => {
Register::Aggregate(AggContext::Sum(Value::Null, SumAggState::default()))
}
AggFunc::Total => {
// The result of total() is always a floating point value.
// No overflow error is ever raised if any prior input was a floating point value.
// Total() never throws an integer overflow.
Register::Aggregate(AggContext::Sum(Value::Float(0.0)))
Register::Aggregate(AggContext::Sum(Value::Float(0.0), SumAggState::default()))
}
AggFunc::Count | AggFunc::Count0 => {
Register::Aggregate(AggContext::Count(Value::Integer(0)))
@@ -3128,25 +3160,62 @@ pub fn op_agg_step(
state.registers[*acc_reg], *acc_reg
);
};
let AggContext::Sum(acc) = agg.borrow_mut() else {
let AggContext::Sum(acc, sum_state) = agg.borrow_mut() else {
unreachable!();
};
match col {
Register::Value(owned_value) => {
match owned_value {
Value::Integer(_) | Value::Float(_) => match acc {
Value::Null => *acc = owned_value.clone(),
_ => *acc += owned_value,
},
Value::Null => {
// Null values are ignored in sum
// Ignore NULLs
}
_ => match acc {
Value::Null => *acc = Value::Float(0.0),
Value::Integer(i) => *acc = Value::Float(*i as f64 + 0.0),
Value::Float(_) => {}
Value::Integer(i) => match acc {
Value::Null => {
*acc = Value::Integer(i);
}
Value::Integer(acc_i) => {
match acc_i.checked_add(i) {
Some(sum) => *acc = Value::Integer(sum),
None => {
// Overflow -> switch to float with KBN summation
let acc_f = *acc_i as f64;
*acc = Value::Float(acc_f);
sum_state.approx = true;
sum_state.ovrfl = true;
apply_kbn_step_int(acc, i, sum_state);
}
}
}
Value::Float(_) => {
apply_kbn_step_int(acc, i, sum_state);
}
_ => unreachable!(),
},
Value::Float(f) => match acc {
Value::Null => {
*acc = Value::Float(f);
}
Value::Integer(i) => {
let i_f = *i as f64;
*acc = Value::Float(i_f);
sum_state.approx = true;
apply_kbn_step(acc, f, sum_state);
}
Value::Float(_) => {
sum_state.approx = true;
apply_kbn_step(acc, f, sum_state);
}
_ => unreachable!(),
},
_ => {
// If any input to sum() is neither an integer nor a NULL, then sum() returns a float
// https://sqlite.org/lang_aggfunc.html
sum_state.approx = true;
}
}
}
_ => unreachable!(),
@@ -3330,18 +3399,23 @@ pub fn op_agg_final(
state.registers[*register] = Register::Value(acc.clone());
}
AggFunc::Sum => {
let AggContext::Sum(acc) = agg.borrow_mut() else {
let AggContext::Sum(acc, sum_state) = agg.borrow_mut() else {
unreachable!();
};
let value = match acc {
Value::Null => Value::Null,
v @ Value::Integer(_) | v @ Value::Float(_) => v.clone(),
_ => unreachable!(),
Value::Null => match sum_state.approx {
true => Value::Float(0.0),
false => Value::Null,
},
Value::Integer(i) if !sum_state.approx && !sum_state.ovrfl => {
Value::Integer(*i)
}
_ => Value::Float(acc.as_float() + sum_state.r_err),
};
state.registers[*register] = Register::Value(value);
}
AggFunc::Total => {
let AggContext::Sum(acc) = agg.borrow_mut() else {
let AggContext::Sum(acc, _) = agg.borrow_mut() else {
unreachable!();
};
let value = match acc {

View File

@@ -7,7 +7,7 @@ mod tests {
use rand::{Rng, SeedableRng};
use rand_chacha::ChaCha8Rng;
use rusqlite::params;
use rusqlite::{params, types::Value};
use crate::{
common::{limbo_exec_rows, rng_from_time, sqlite_exec_rows, TempDatabase},
@@ -1417,6 +1417,51 @@ mod tests {
}
}
}
#[test]
// Simple fuzz test for SUM with floats
pub fn sum_agg_fuzz_floats() {
let _ = env_logger::try_init();
let (mut rng, seed) = rng_from_time();
log::info!("seed: {seed}");
for _ in 0..100 {
let db = TempDatabase::new_empty(false);
let limbo_conn = db.connect_limbo();
let sqlite_conn = rusqlite::Connection::open_in_memory().unwrap();
limbo_exec_rows(&db, &limbo_conn, "CREATE TABLE t(x)");
sqlite_exec_rows(&sqlite_conn, "CREATE TABLE t(x)");
// Insert 50-100 mixed values: floats, text, NULL
let mut values = Vec::new();
for _ in 0..rng.random_range(50..=100) {
let value = rng.random_range(-100.0..100.0).to_string();
values.push(format!("({value})"));
}
let insert = format!("INSERT INTO t VALUES {}", values.join(","));
limbo_exec_rows(&db, &limbo_conn, &insert);
sqlite_exec_rows(&sqlite_conn, &insert);
let query = "SELECT sum(x) FROM t ORDER BY x";
let limbo_result = limbo_exec_rows(&db, &limbo_conn, query);
let sqlite_result = sqlite_exec_rows(&sqlite_conn, query);
let limbo_val = match limbo_result.first().and_then(|row| row.first()) {
Some(Value::Real(f)) => *f,
Some(Value::Null) | None => 0.0,
_ => panic!("Unexpected type in limbo result: {limbo_result:?}"),
};
let sqlite_val = match sqlite_result.first().and_then(|row| row.first()) {
Some(Value::Real(f)) => *f,
Some(Value::Null) | None => 0.0,
_ => panic!("Unexpected type in limbo result: {limbo_result:?}"),
};
assert_eq!(limbo_val, sqlite_val, "seed: {seed}, values: {values:?}");
}
}
#[test]
// Simple fuzz test for SUM with mixed numeric/non-numeric values (issue #2133)