fix: SUM returns float for mixed numeric/non-numeric types

This commit is contained in:
Axel
2025-07-20 06:57:25 +02:00
parent d1fdc7dbc8
commit 73389abf09
2 changed files with 30 additions and 9 deletions

View File

@@ -498,7 +498,7 @@ impl Value {
#[derive(Debug, Clone, PartialEq)]
pub enum AggContext {
Avg(Value, Value), // acc and count
Sum(Value),
Sum(Value, bool), // acc and has_non_numeric
Count(Value),
Max(Option<Value>),
Min(Option<Value>),
@@ -522,7 +522,7 @@ impl AggContext {
pub fn final_value(&self) -> &Value {
match self {
Self::Avg(acc, _count) => acc,
Self::Sum(acc) => acc,
Self::Sum(acc, _) => acc,
Self::Count(count) => count,
Self::Max(max) => max.as_ref().unwrap_or(&NULL),
Self::Min(min) => min.as_ref().unwrap_or(&NULL),
@@ -596,7 +596,7 @@ impl PartialOrd<AggContext> for AggContext {
fn partial_cmp(&self, other: &AggContext) -> Option<std::cmp::Ordering> {
match (self, other) {
(Self::Avg(a, _), Self::Avg(b, _)) => a.partial_cmp(b),
(Self::Sum(a), Self::Sum(b)) => a.partial_cmp(b),
(Self::Sum(a, _), Self::Sum(b, _)) => a.partial_cmp(b),
(Self::Count(a), Self::Count(b)) => a.partial_cmp(b),
(Self::Max(a), Self::Max(b)) => a.partial_cmp(b),
(Self::Min(a), Self::Min(b)) => a.partial_cmp(b),

View File

@@ -3046,12 +3046,12 @@ pub fn op_agg_step(
AggFunc::Avg => {
Register::Aggregate(AggContext::Avg(Value::Float(0.0), Value::Integer(0)))
}
AggFunc::Sum => Register::Aggregate(AggContext::Sum(Value::Null)),
AggFunc::Sum => Register::Aggregate(AggContext::Sum(Value::Null, false)),
AggFunc::Total => {
// The result of total() is always a floating point value.
// No overflow error is ever raised if any prior input was a floating point value.
// Total() never throws an integer overflow.
Register::Aggregate(AggContext::Sum(Value::Float(0.0)))
Register::Aggregate(AggContext::Sum(Value::Float(0.0), false))
}
AggFunc::Count | AggFunc::Count0 => {
Register::Aggregate(AggContext::Count(Value::Integer(0)))
@@ -3109,12 +3109,27 @@ pub fn op_agg_step(
state.registers[*acc_reg], *acc_reg
);
};
let AggContext::Sum(acc) = agg.borrow_mut() else {
let AggContext::Sum(acc, has_non_numeric) = agg.borrow_mut() else {
unreachable!();
};
match col {
Register::Value(owned_value) => {
*acc += owned_value;
match &owned_value {
Value::Integer(_) | Value::Float(_) => {
// Convert to float if we have mixed numeric types
if let (Value::Integer(i), Value::Float(_)) =
(acc.clone(), &owned_value)
{
*acc = Value::Float(i as f64);
*has_non_numeric = true;
}
*acc += owned_value;
}
_ => {
// Skip and mark non-numeric values (NULL, text, blob)
*has_non_numeric = true;
}
}
}
_ => unreachable!(),
}
@@ -3297,11 +3312,17 @@ pub fn op_agg_final(
state.registers[*register] = Register::Value(acc.clone());
}
AggFunc::Sum | AggFunc::Total => {
let AggContext::Sum(acc) = agg.borrow_mut() else {
let AggContext::Sum(acc, has_non_numeric) = agg.borrow_mut() else {
unreachable!();
};
let value = match acc {
Value::Integer(i) => Value::Integer(*i),
Value::Integer(i) => {
if *has_non_numeric {
Value::Float(*i as f64)
} else {
Value::Integer(*i)
}
}
Value::Float(f) => Value::Float(*f),
_ => Value::Float(0.0),
};