mirror of
https://github.com/aljazceru/turso.git
synced 2025-12-28 13:34:24 +01:00
Merge 'Fix sum() to follow the SQLite semantics' from FamHaggs
### Follow SUM [spec](https://sqlite.org/lang_aggfunc.html) This PR updates the `SUM` aggregation logic to follow the [Kahan–Babushka–Neumaier summation algorithm](https://en.wikipedia.org/wiki/Kahan_summation_algorithm), consistent with SQLite’s implementation. It improves the numerical stability of floating-point summation. This fixes issue #2252. I added a fuzz test to ensure the compatibility of the implementations. I also fixed the return types for `SUM` to match SQLite’s documented behavior. This was previously discussed in [#2182](https://github.com/tursodatabase/turso/pull/2182), but part of the logic was later unintentionally overwritten by [#2265](https://github.com/tursodatabase/turso/pull/2265). I introduced two helper functions, `apply_kbn_step` and `apply_kbn_step_int`, in `vdbe/execute.rs` to handle floating-point and integer accumulation respectively. However, I’m new to this codebase and would welcome constructive feedback on whether there’s a better place for these helpers. Reviewed-by: Preston Thorpe (@PThorpe92) Closes #2270
This commit is contained in:
@@ -259,6 +259,13 @@ impl Value {
|
||||
_ => panic!("as_blob must be called only for Value::Blob"),
|
||||
}
|
||||
}
|
||||
pub fn as_float(&self) -> f64 {
|
||||
match self {
|
||||
Value::Float(f) => *f,
|
||||
Value::Integer(i) => *i as f64,
|
||||
_ => panic!("as_float must be called only for Value::Float or Value::Integer"),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_text(text: &str) -> Self {
|
||||
Value::Text(Text::new(text))
|
||||
@@ -495,10 +502,26 @@ impl Value {
|
||||
}
|
||||
}
|
||||
|
||||
/// Bookkeeping that travels next to the running `SUM` accumulator.
///
/// The derived `Default` produces a zero error term with both flags cleared,
/// which is the state a fresh aggregation starts from.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct SumAggState {
    /// Error term for Kahan-Babushka-Neumaier summation.
    pub r_err: f64,
    /// True if any non-integer value was input to the sum.
    pub approx: bool,
    /// Integer overflow seen.
    pub ovrfl: bool,
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub enum AggContext {
|
||||
Avg(Value, Value), // acc and count
|
||||
Sum(Value),
|
||||
Sum(Value, SumAggState),
|
||||
Count(Value),
|
||||
Max(Option<Value>),
|
||||
Min(Option<Value>),
|
||||
@@ -522,7 +545,7 @@ impl AggContext {
|
||||
pub fn final_value(&self) -> &Value {
|
||||
match self {
|
||||
Self::Avg(acc, _count) => acc,
|
||||
Self::Sum(acc) => acc,
|
||||
Self::Sum(acc, _) => acc,
|
||||
Self::Count(count) => count,
|
||||
Self::Max(max) => max.as_ref().unwrap_or(&NULL),
|
||||
Self::Min(min) => min.as_ref().unwrap_or(&NULL),
|
||||
@@ -596,7 +619,7 @@ impl PartialOrd<AggContext> for AggContext {
|
||||
fn partial_cmp(&self, other: &AggContext) -> Option<std::cmp::Ordering> {
|
||||
match (self, other) {
|
||||
(Self::Avg(a, _), Self::Avg(b, _)) => a.partial_cmp(b),
|
||||
(Self::Sum(a), Self::Sum(b)) => a.partial_cmp(b),
|
||||
(Self::Sum(a, _), Self::Sum(b, _)) => a.partial_cmp(b),
|
||||
(Self::Count(a), Self::Count(b)) => a.partial_cmp(b),
|
||||
(Self::Max(a), Self::Max(b)) => a.partial_cmp(b),
|
||||
(Self::Min(a), Self::Min(b)) => a.partial_cmp(b),
|
||||
|
||||
@@ -45,7 +45,10 @@ use crate::{
|
||||
|
||||
use crate::{
|
||||
storage::wal::CheckpointResult,
|
||||
types::{AggContext, Cursor, ExternalAggState, IOResult, SeekKey, SeekOp, Value, ValueType},
|
||||
types::{
|
||||
AggContext, Cursor, ExternalAggState, IOResult, SeekKey, SeekOp, SumAggState, Value,
|
||||
ValueType,
|
||||
},
|
||||
util::{
|
||||
cast_real_to_integer, cast_text_to_integer, cast_text_to_numeric, cast_text_to_real,
|
||||
checked_cast_text_to_numeric, parse_schema_rows, RoundToPrecision,
|
||||
@@ -3044,6 +3047,33 @@ pub fn op_decr_jump_zero(
|
||||
Ok(InsnFunctionStepResult::Step)
|
||||
}
|
||||
|
||||
fn apply_kbn_step(acc: &mut Value, r: f64, state: &mut SumAggState) {
|
||||
let s = acc.as_float();
|
||||
let t = s + r;
|
||||
let correction = if s.abs() > r.abs() {
|
||||
(s - t) + r
|
||||
} else {
|
||||
(r - t) + s
|
||||
};
|
||||
state.r_err += correction;
|
||||
*acc = Value::Float(t);
|
||||
}
|
||||
|
||||
// Add a (possibly large) integer to the running sum.
|
||||
fn apply_kbn_step_int(acc: &mut Value, i: i64, state: &mut SumAggState) {
|
||||
const THRESHOLD: i64 = 4503599627370496; // 2^52
|
||||
|
||||
if i <= -THRESHOLD || i >= THRESHOLD {
|
||||
let i_sm = i % 16384;
|
||||
let i_big = i - i_sm;
|
||||
|
||||
apply_kbn_step(acc, i_big as f64, state);
|
||||
apply_kbn_step(acc, i_sm as f64, state);
|
||||
} else {
|
||||
apply_kbn_step(acc, i as f64, state);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn op_agg_step(
|
||||
program: &Program,
|
||||
state: &mut ProgramState,
|
||||
@@ -3065,12 +3095,14 @@ pub fn op_agg_step(
|
||||
AggFunc::Avg => {
|
||||
Register::Aggregate(AggContext::Avg(Value::Float(0.0), Value::Integer(0)))
|
||||
}
|
||||
AggFunc::Sum => Register::Aggregate(AggContext::Sum(Value::Null)),
|
||||
AggFunc::Sum => {
|
||||
Register::Aggregate(AggContext::Sum(Value::Null, SumAggState::default()))
|
||||
}
|
||||
AggFunc::Total => {
|
||||
// The result of total() is always a floating point value.
|
||||
// No overflow error is ever raised if any prior input was a floating point value.
|
||||
// Total() never throws an integer overflow.
|
||||
Register::Aggregate(AggContext::Sum(Value::Float(0.0)))
|
||||
Register::Aggregate(AggContext::Sum(Value::Float(0.0), SumAggState::default()))
|
||||
}
|
||||
AggFunc::Count | AggFunc::Count0 => {
|
||||
Register::Aggregate(AggContext::Count(Value::Integer(0)))
|
||||
@@ -3128,25 +3160,62 @@ pub fn op_agg_step(
|
||||
state.registers[*acc_reg], *acc_reg
|
||||
);
|
||||
};
|
||||
let AggContext::Sum(acc) = agg.borrow_mut() else {
|
||||
let AggContext::Sum(acc, sum_state) = agg.borrow_mut() else {
|
||||
unreachable!();
|
||||
};
|
||||
match col {
|
||||
Register::Value(owned_value) => {
|
||||
match owned_value {
|
||||
Value::Integer(_) | Value::Float(_) => match acc {
|
||||
Value::Null => *acc = owned_value.clone(),
|
||||
_ => *acc += owned_value,
|
||||
},
|
||||
Value::Null => {
|
||||
// Null values are ignored in sum
|
||||
// Ignore NULLs
|
||||
}
|
||||
_ => match acc {
|
||||
Value::Null => *acc = Value::Float(0.0),
|
||||
Value::Integer(i) => *acc = Value::Float(*i as f64 + 0.0),
|
||||
Value::Float(_) => {}
|
||||
|
||||
Value::Integer(i) => match acc {
|
||||
Value::Null => {
|
||||
*acc = Value::Integer(i);
|
||||
}
|
||||
Value::Integer(acc_i) => {
|
||||
match acc_i.checked_add(i) {
|
||||
Some(sum) => *acc = Value::Integer(sum),
|
||||
None => {
|
||||
// Overflow -> switch to float with KBN summation
|
||||
let acc_f = *acc_i as f64;
|
||||
*acc = Value::Float(acc_f);
|
||||
sum_state.approx = true;
|
||||
sum_state.ovrfl = true;
|
||||
|
||||
apply_kbn_step_int(acc, i, sum_state);
|
||||
}
|
||||
}
|
||||
}
|
||||
Value::Float(_) => {
|
||||
apply_kbn_step_int(acc, i, sum_state);
|
||||
}
|
||||
_ => unreachable!(),
|
||||
},
|
||||
|
||||
Value::Float(f) => match acc {
|
||||
Value::Null => {
|
||||
*acc = Value::Float(f);
|
||||
}
|
||||
Value::Integer(i) => {
|
||||
let i_f = *i as f64;
|
||||
*acc = Value::Float(i_f);
|
||||
sum_state.approx = true;
|
||||
apply_kbn_step(acc, f, sum_state);
|
||||
}
|
||||
Value::Float(_) => {
|
||||
sum_state.approx = true;
|
||||
apply_kbn_step(acc, f, sum_state);
|
||||
}
|
||||
_ => unreachable!(),
|
||||
},
|
||||
|
||||
_ => {
|
||||
// If any input to sum() is neither an integer nor a NULL, then sum() returns a float
|
||||
// https://sqlite.org/lang_aggfunc.html
|
||||
sum_state.approx = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => unreachable!(),
|
||||
@@ -3330,18 +3399,23 @@ pub fn op_agg_final(
|
||||
state.registers[*register] = Register::Value(acc.clone());
|
||||
}
|
||||
AggFunc::Sum => {
|
||||
let AggContext::Sum(acc) = agg.borrow_mut() else {
|
||||
let AggContext::Sum(acc, sum_state) = agg.borrow_mut() else {
|
||||
unreachable!();
|
||||
};
|
||||
let value = match acc {
|
||||
Value::Null => Value::Null,
|
||||
v @ Value::Integer(_) | v @ Value::Float(_) => v.clone(),
|
||||
_ => unreachable!(),
|
||||
Value::Null => match sum_state.approx {
|
||||
true => Value::Float(0.0),
|
||||
false => Value::Null,
|
||||
},
|
||||
Value::Integer(i) if !sum_state.approx && !sum_state.ovrfl => {
|
||||
Value::Integer(*i)
|
||||
}
|
||||
_ => Value::Float(acc.as_float() + sum_state.r_err),
|
||||
};
|
||||
state.registers[*register] = Register::Value(value);
|
||||
}
|
||||
AggFunc::Total => {
|
||||
let AggContext::Sum(acc) = agg.borrow_mut() else {
|
||||
let AggContext::Sum(acc, _) = agg.borrow_mut() else {
|
||||
unreachable!();
|
||||
};
|
||||
let value = match acc {
|
||||
|
||||
@@ -7,7 +7,7 @@ mod tests {
|
||||
|
||||
use rand::{Rng, SeedableRng};
|
||||
use rand_chacha::ChaCha8Rng;
|
||||
use rusqlite::params;
|
||||
use rusqlite::{params, types::Value};
|
||||
|
||||
use crate::{
|
||||
common::{limbo_exec_rows, rng_from_time, sqlite_exec_rows, TempDatabase},
|
||||
@@ -1417,6 +1417,51 @@ mod tests {
|
||||
}
|
||||
}
|
||||
}
|
||||
#[test]
// Simple fuzz test for SUM with floats: the same random rows are inserted
// into limbo and into SQLite, and `sum(x)` must be identical in both.
pub fn sum_agg_fuzz_floats() {
    let _ = env_logger::try_init();

    let (mut rng, seed) = rng_from_time();
    log::info!("seed: {seed}");

    for _ in 0..100 {
        let db = TempDatabase::new_empty(false);
        let limbo_conn = db.connect_limbo();
        let sqlite_conn = rusqlite::Connection::open_in_memory().unwrap();

        limbo_exec_rows(&db, &limbo_conn, "CREATE TABLE t(x)");
        sqlite_exec_rows(&sqlite_conn, "CREATE TABLE t(x)");

        // Insert 50-100 random values drawn from -100.0..100.0
        let mut values = Vec::new();
        for _ in 0..rng.random_range(50..=100) {
            let value = rng.random_range(-100.0..100.0).to_string();
            values.push(format!("({value})"));
        }

        let insert = format!("INSERT INTO t VALUES {}", values.join(","));
        limbo_exec_rows(&db, &limbo_conn, &insert);
        sqlite_exec_rows(&sqlite_conn, &insert);

        let query = "SELECT sum(x) FROM t ORDER BY x";
        let limbo_result = limbo_exec_rows(&db, &limbo_conn, query);
        let sqlite_result = sqlite_exec_rows(&sqlite_conn, query);

        // A NULL or missing result is compared as 0.0 on both sides.
        let limbo_val = match limbo_result.first().and_then(|row| row.first()) {
            Some(Value::Real(f)) => *f,
            Some(Value::Null) | None => 0.0,
            _ => panic!("Unexpected type in limbo result: {limbo_result:?}"),
        };

        let sqlite_val = match sqlite_result.first().and_then(|row| row.first()) {
            Some(Value::Real(f)) => *f,
            Some(Value::Null) | None => 0.0,
            // Report the SQLite rows here, not the limbo ones.
            _ => panic!("Unexpected type in sqlite result: {sqlite_result:?}"),
        };
        assert_eq!(limbo_val, sqlite_val, "seed: {seed}, values: {values:?}");
    }
}
|
||||
|
||||
#[test]
|
||||
// Simple fuzz test for SUM with mixed numeric/non-numeric values (issue #2133)
|
||||
|
||||
Reference in New Issue
Block a user