Replicate the sqlite Kahan-Babaska-Neumaier algorithm

This commit is contained in:
FHaggs
2025-07-25 15:25:29 -03:00
parent f0ffff3c8e
commit 54edfa09d5
2 changed files with 108 additions and 28 deletions

View File

@@ -502,10 +502,26 @@ impl Value {
}
}
#[derive(Debug, Clone, PartialEq)]
pub struct SumAggState {
pub r_err: f64, // Error term for Kahan-Babushka-Neumaier summation
pub approx: bool, // True if any non-integer value was input to the sum
pub ovrfl: bool, // Integer overflow seen
}
impl Default for SumAggState {
fn default() -> Self {
Self {
r_err: 0.0,
approx: false,
ovrfl: false,
}
}
}
#[derive(Debug, Clone, PartialEq)]
pub enum AggContext {
Avg(Value, Value), // acc and count
Sum(Value, f64), // Error term for Kahan-Babushka-Neumaier summation
Sum(Value, SumAggState),
Count(Value),
Max(Option<Value>),
Min(Option<Value>),

View File

@@ -45,7 +45,10 @@ use crate::{
use crate::{
storage::wal::CheckpointResult,
types::{AggContext, Cursor, ExternalAggState, IOResult, SeekKey, SeekOp, Value, ValueType},
types::{
AggContext, Cursor, ExternalAggState, IOResult, SeekKey, SeekOp, SumAggState, Value,
ValueType,
},
util::{
cast_real_to_integer, cast_text_to_integer, cast_text_to_numeric, cast_text_to_real,
checked_cast_text_to_numeric, parse_schema_rows, RoundToPrecision,
@@ -3044,6 +3047,33 @@ pub fn op_decr_jump_zero(
Ok(InsnFunctionStepResult::Step)
}
fn apply_kbn_step(acc: &mut Value, r: f64, state: &mut SumAggState) {
let s = acc.as_float();
let t = s + r;
let correction = if s.abs() > r.abs() {
(s - t) + r
} else {
(r - t) + s
};
state.r_err += correction;
*acc = Value::Float(t);
}
// Add a (possibly large) integer to the running sum.
fn apply_kbn_step_int(acc: &mut Value, i: i64, state: &mut SumAggState) {
const THRESHOLD: i64 = 4503599627370496; // 2^52
if i <= -THRESHOLD || i >= THRESHOLD {
let i_sm = i % 16384;
let i_big = i - i_sm;
apply_kbn_step(acc, i_big as f64, state);
apply_kbn_step(acc, i_sm as f64, state);
} else {
apply_kbn_step(acc, i as f64, state);
}
}
pub fn op_agg_step(
program: &Program,
state: &mut ProgramState,
@@ -3065,12 +3095,14 @@ pub fn op_agg_step(
AggFunc::Avg => {
Register::Aggregate(AggContext::Avg(Value::Float(0.0), Value::Integer(0)))
}
AggFunc::Sum => Register::Aggregate(AggContext::Sum(Value::Null, 0.0)),
AggFunc::Sum => {
Register::Aggregate(AggContext::Sum(Value::Null, SumAggState::default()))
}
AggFunc::Total => {
// The result of total() is always a floating point value.
// No overflow error is ever raised if any prior input was a floating point value.
// Total() never throws an integer overflow.
Register::Aggregate(AggContext::Sum(Value::Float(0.0), 0.0))
Register::Aggregate(AggContext::Sum(Value::Float(0.0), SumAggState::default()))
}
AggFunc::Count | AggFunc::Count0 => {
Register::Aggregate(AggContext::Count(Value::Integer(0)))
@@ -3128,35 +3160,62 @@ pub fn op_agg_step(
state.registers[*acc_reg], *acc_reg
);
};
let AggContext::Sum(acc, c) = agg.borrow_mut() else {
let AggContext::Sum(acc, sum_state) = agg.borrow_mut() else {
unreachable!();
};
match col {
Register::Value(owned_value) => {
match owned_value {
ref v @ Value::Integer(_) | ref v @ Value::Float(_) => match acc {
Value::Null => *acc = owned_value.clone(),
Value::Float(_) => {
let x = acc.as_float();
let y = v.as_float() - *c;
let t = x + y;
*c = (t - x) - y;
*acc = Value::Float(t);
}
Value::Integer(_) => {
*acc += owned_value;
}
_ => unreachable!(),
},
Value::Null => {
// Null values are ignored in sum
// Ignore NULLs
}
_ => match acc {
Value::Null => *acc = Value::Float(0.0),
Value::Integer(i) => *acc = Value::Float(*i as f64 + 0.0),
Value::Float(_) => {}
Value::Integer(i) => match acc {
Value::Null => {
*acc = Value::Integer(i);
}
Value::Integer(acc_i) => {
match acc_i.checked_add(i) {
Some(sum) => *acc = Value::Integer(sum),
None => {
// Overflow -> switch to float with KBN summation
let acc_f = *acc_i as f64;
*acc = Value::Float(acc_f);
sum_state.approx = true;
sum_state.ovrfl = true;
apply_kbn_step_int(acc, i, sum_state);
}
}
}
Value::Float(_) => {
apply_kbn_step_int(acc, i, sum_state);
}
_ => unreachable!(),
},
Value::Float(f) => match acc {
Value::Null => {
*acc = Value::Float(f);
}
Value::Integer(i) => {
let i_f = *i as f64;
*acc = Value::Float(i_f);
sum_state.approx = true;
apply_kbn_step(acc, f, sum_state);
}
Value::Float(_) => {
sum_state.approx = true;
apply_kbn_step(acc, f, sum_state);
}
_ => unreachable!(),
},
_ => {
// If any input to sum() is neither an integer nor a NULL, then sum() returns a float
// https://sqlite.org/lang_aggfunc.html
sum_state.approx = true;
}
}
}
_ => unreachable!(),
@@ -3340,13 +3399,18 @@ pub fn op_agg_final(
state.registers[*register] = Register::Value(acc.clone());
}
AggFunc::Sum => {
let AggContext::Sum(acc, _) = agg.borrow_mut() else {
let AggContext::Sum(acc, sum_state) = agg.borrow_mut() else {
unreachable!();
};
let value = match acc {
Value::Null => Value::Null,
v @ Value::Integer(_) | v @ Value::Float(_) => v.clone(),
_ => unreachable!(),
Value::Null => match sum_state.approx {
true => Value::Float(0.0),
false => Value::Null,
},
Value::Integer(i) if !sum_state.approx && !sum_state.ovrfl => {
Value::Integer(*i)
}
_ => Value::Float(acc.as_float() + sum_state.r_err),
};
state.registers[*register] = Register::Value(value);
}