mirror of
https://github.com/aljazceru/turso.git
synced 2026-02-20 07:25:14 +01:00
Merge 'Simplify sum() aggregation logic' from bit-aloo
This refactors AggContext::Sum by removing the extra bool flag and simplifying type handling during aggregation: Closes #2265
This commit is contained in:
@@ -498,7 +498,7 @@ impl Value {
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub enum AggContext {
|
||||
Avg(Value, Value), // acc and count
|
||||
Sum(Value, bool), // acc and has_non_numeric
|
||||
Sum(Value),
|
||||
Count(Value),
|
||||
Max(Option<Value>),
|
||||
Min(Option<Value>),
|
||||
@@ -522,7 +522,7 @@ impl AggContext {
|
||||
pub fn final_value(&self) -> &Value {
|
||||
match self {
|
||||
Self::Avg(acc, _count) => acc,
|
||||
Self::Sum(acc, _) => acc,
|
||||
Self::Sum(acc) => acc,
|
||||
Self::Count(count) => count,
|
||||
Self::Max(max) => max.as_ref().unwrap_or(&NULL),
|
||||
Self::Min(min) => min.as_ref().unwrap_or(&NULL),
|
||||
@@ -596,7 +596,7 @@ impl PartialOrd<AggContext> for AggContext {
|
||||
fn partial_cmp(&self, other: &AggContext) -> Option<std::cmp::Ordering> {
|
||||
match (self, other) {
|
||||
(Self::Avg(a, _), Self::Avg(b, _)) => a.partial_cmp(b),
|
||||
(Self::Sum(a, _), Self::Sum(b, _)) => a.partial_cmp(b),
|
||||
(Self::Sum(a), Self::Sum(b)) => a.partial_cmp(b),
|
||||
(Self::Count(a), Self::Count(b)) => a.partial_cmp(b),
|
||||
(Self::Max(a), Self::Max(b)) => a.partial_cmp(b),
|
||||
(Self::Min(a), Self::Min(b)) => a.partial_cmp(b),
|
||||
|
||||
@@ -3065,12 +3065,12 @@ pub fn op_agg_step(
|
||||
AggFunc::Avg => {
|
||||
Register::Aggregate(AggContext::Avg(Value::Float(0.0), Value::Integer(0)))
|
||||
}
|
||||
AggFunc::Sum => Register::Aggregate(AggContext::Sum(Value::Null, false)),
|
||||
AggFunc::Sum => Register::Aggregate(AggContext::Sum(Value::Null)),
|
||||
AggFunc::Total => {
|
||||
// The result of total() is always a floating point value.
|
||||
// No overflow error is ever raised if any prior input was a floating point value.
|
||||
// Total() never throws an integer overflow.
|
||||
Register::Aggregate(AggContext::Sum(Value::Float(0.0), false))
|
||||
Register::Aggregate(AggContext::Sum(Value::Float(0.0)))
|
||||
}
|
||||
AggFunc::Count | AggFunc::Count0 => {
|
||||
Register::Aggregate(AggContext::Count(Value::Integer(0)))
|
||||
@@ -3128,32 +3128,25 @@ pub fn op_agg_step(
|
||||
state.registers[*acc_reg], *acc_reg
|
||||
);
|
||||
};
|
||||
let AggContext::Sum(acc, has_non_numeric) = agg.borrow_mut() else {
|
||||
let AggContext::Sum(acc) = agg.borrow_mut() else {
|
||||
unreachable!();
|
||||
};
|
||||
match col {
|
||||
Register::Value(owned_value) => {
|
||||
match owned_value {
|
||||
Value::Integer(_) | Value::Float(_) => {
|
||||
// Promote accumulator to float if mixing integer and float
|
||||
if matches!((&*acc, &owned_value), (Value::Integer(_), Value::Float(_)))
|
||||
|| matches!(
|
||||
(&*acc, &owned_value),
|
||||
(Value::Float(_), Value::Integer(_))
|
||||
)
|
||||
{
|
||||
if let Value::Integer(i) = acc {
|
||||
*acc = Value::Float(*i as f64);
|
||||
}
|
||||
}
|
||||
*acc += owned_value;
|
||||
}
|
||||
Value::Integer(_) | Value::Float(_) => match acc {
|
||||
Value::Null => *acc = owned_value.clone(),
|
||||
_ => *acc += owned_value,
|
||||
},
|
||||
Value::Null => {
|
||||
// Null values are ignored in sum
|
||||
}
|
||||
_ => {
|
||||
*has_non_numeric = true;
|
||||
}
|
||||
_ => match acc {
|
||||
Value::Null => *acc = Value::Float(0.0),
|
||||
Value::Integer(i) => *acc = Value::Float(*i as f64 + 0.0),
|
||||
Value::Float(_) => {}
|
||||
_ => unreachable!(),
|
||||
},
|
||||
}
|
||||
}
|
||||
_ => unreachable!(),
|
||||
@@ -3337,44 +3330,25 @@ pub fn op_agg_final(
|
||||
state.registers[*register] = Register::Value(acc.clone());
|
||||
}
|
||||
AggFunc::Sum => {
|
||||
let AggContext::Sum(acc, has_non_numeric) = agg.borrow_mut() else {
|
||||
let AggContext::Sum(acc) = agg.borrow_mut() else {
|
||||
unreachable!();
|
||||
};
|
||||
let value = match acc {
|
||||
Value::Integer(i) => {
|
||||
if *has_non_numeric {
|
||||
Value::Float(*i as f64)
|
||||
} else {
|
||||
Value::Integer(*i)
|
||||
}
|
||||
}
|
||||
Value::Float(f) => Value::Float(*f),
|
||||
_ => {
|
||||
if *has_non_numeric {
|
||||
// Non-numeric values encountered
|
||||
Value::Float(0.0)
|
||||
} else {
|
||||
// Only NULL values encountered
|
||||
Value::Null
|
||||
}
|
||||
}
|
||||
Value::Null => Value::Null,
|
||||
v @ Value::Integer(_) | v @ Value::Float(_) => v.clone(),
|
||||
_ => unreachable!(),
|
||||
};
|
||||
state.registers[*register] = Register::Value(value);
|
||||
}
|
||||
AggFunc::Total => {
|
||||
let AggContext::Sum(acc, has_non_numeric) = agg.borrow_mut() else {
|
||||
let AggContext::Sum(acc) = agg.borrow_mut() else {
|
||||
unreachable!();
|
||||
};
|
||||
let value = match acc {
|
||||
Value::Integer(i) => {
|
||||
if *has_non_numeric {
|
||||
Value::Float(*i as f64)
|
||||
} else {
|
||||
Value::Integer(*i)
|
||||
}
|
||||
}
|
||||
Value::Null => Value::Float(0.0),
|
||||
Value::Integer(i) => Value::Float(*i as f64),
|
||||
Value::Float(f) => Value::Float(*f),
|
||||
_ => Value::Float(0.0),
|
||||
_ => unreachable!(),
|
||||
};
|
||||
state.registers[*register] = Register::Value(value);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user