Merge 'fix avg aggregation' from Nikita Sivukhin

- ignore NULL rows as SQLite do
- emit NULL instead of NaN when no rows were aggregated
- adjust agg column alias name

Reviewed-by: Jussi Saurio <jussi.saurio@gmail.com>

Closes #3376
This commit is contained in:
Preston Thorpe
2025-09-26 09:59:50 -04:00
committed by GitHub
4 changed files with 92 additions and 13 deletions

View File

@@ -72,6 +72,37 @@ test('explicit connect', async () => {
expect(await db.prepare("SELECT 1 as x").all()).toEqual([{ x: 1 }]);
})
test('avg-bug', async () => {
const db = await connect(':memory:');
const create = db.prepare(`create table "aggregate_table" (
"id" integer primary key autoincrement not null,
"name" text not null,
"a" integer,
"b" integer,
"c" integer,
"null_only" integer
);`);
await create.run();
const insert = db.prepare(
`insert into "aggregate_table" ("id", "name", "a", "b", "c", "null_only") values (null, ?, ?, ?, ?, null), (null, ?, ?, ?, ?, null), (null, ?, ?, ?, ?, null), (null, ?, ?, ?, ?, null), (null, ?, ?, ?, ?, null), (null, ?, ?, ?, ?, null), (null, ?, ?, ?, ?, null);`,
);
await insert.run(
'value 1', 5, 10, 20,
'value 1', 5, 20, 30,
'value 2', 10, 50, 60,
'value 3', 20, 20, null,
'value 4', null, 90, 120,
'value 5', 80, 10, null,
'value 6', null, null, 150,
);
expect(await db.prepare(`select avg("a") from "aggregate_table";`).get()).toEqual({ 'avg (aggregate_table.a)': 24 });
expect(await db.prepare(`select avg("null_only") from "aggregate_table";`).get()).toEqual({ 'avg (aggregate_table.null_only)': null });
expect(await db.prepare(`select avg(distinct "b") from "aggregate_table";`).get()).toEqual({ 'avg (DISTINCT aggregate_table.b)': 42.5 });
})
test('on-disk db', async () => {
const path = `test-${(Math.random() * 10000) | 0}.db`;
try {

View File

@@ -40,6 +40,7 @@ pub mod numeric;
mod numeric;
use crate::storage::checksum::CHECKSUM_REQUIRED_RESERVED_BYTES;
use crate::translate::display::PlanContext;
use crate::translate::pragma::TURSO_CDC_DEFAULT_TABLE_NAME;
#[cfg(all(feature = "fs", feature = "conn_raw_api"))]
use crate::types::{WalFrameInfo, WalState};
@@ -91,6 +92,7 @@ pub use storage::{
};
use tracing::{instrument, Level};
use turso_macros::match_ignore_ascii_case;
use turso_parser::ast::fmt::ToTokens;
use turso_parser::{ast, ast::Cmd, parser::Parser};
use types::IOResult;
pub use types::RefValue;
@@ -2562,7 +2564,11 @@ impl Statement {
let column = &self.program.result_columns.get(idx).expect("No column");
match column.name(&self.program.table_references) {
Some(name) => Cow::Borrowed(name),
None => Cow::Owned(column.expr.to_string()),
None => {
let tables = [&self.program.table_references];
let ctx = PlanContext(&tables);
Cow::Owned(column.expr.displayer(&ctx).to_string())
}
}
}
QueryMode::Explain => Cow::Borrowed(EXPLAIN_COLUMNS[idx]),

View File

@@ -3637,17 +3637,21 @@ pub fn op_agg_step(
match func {
AggFunc::Avg => {
let col = state.registers[*col].clone();
let Register::Aggregate(agg) = state.registers[*acc_reg].borrow_mut() else {
panic!(
"Unexpected value {:?} in AggStep at register {}",
state.registers[*acc_reg], *acc_reg
);
};
let AggContext::Avg(acc, count) = agg.borrow_mut() else {
unreachable!();
};
*acc = acc.exec_add(col.get_value());
*count += 1;
// > The avg() function returns the average value of all non-NULL X within a group
// https://sqlite.org/lang_aggfunc.html#avg
if !col.is_null() {
let Register::Aggregate(agg) = state.registers[*acc_reg].borrow_mut() else {
panic!(
"Unexpected value {:?} in AggStep at register {}",
state.registers[*acc_reg], *acc_reg
);
};
let AggContext::Avg(acc, count) = agg.borrow_mut() else {
unreachable!();
};
*acc = acc.exec_add(col.get_value());
*count += 1;
}
}
AggFunc::Sum | AggFunc::Total => {
let col = state.registers[*col].clone();
@@ -3915,7 +3919,11 @@ pub fn op_agg_final(
let AggContext::Avg(acc, count) = agg else {
unreachable!();
};
let acc = acc.clone() / count.clone();
let acc = if count.as_int() == Some(0) {
Value::Null
} else {
acc.clone() / count.clone()
};
state.registers[dest_reg] = Register::Value(acc);
}
AggFunc::Sum => {

View File

@@ -750,3 +750,37 @@ fn test_cte_alias() -> anyhow::Result<()> {
}
Ok(())
}
#[test]
fn test_avg_agg() -> anyhow::Result<()> {
let tmp_db = TempDatabase::new_with_rusqlite("create table t (x, y);", false);
let conn = tmp_db.connect_limbo();
conn.execute("insert into t values (1, null), (2, null), (3, null), (null, null), (4, null)")?;
let mut rows = Vec::new();
let mut stmt = conn.prepare("select avg(x), avg(y) from t")?;
loop {
match stmt.step()? {
StepResult::Row => {
let row = stmt.row().unwrap();
rows.push(row.get_values().cloned().collect::<Vec<_>>());
}
StepResult::Done => break,
StepResult::IO => stmt.run_once()?,
_ => panic!("Unexpected step result"),
}
}
assert_eq!(stmt.num_columns(), 2);
assert_eq!(stmt.get_column_name(0), "avg (t.x)");
assert_eq!(stmt.get_column_name(1), "avg (t.y)");
assert_eq!(
rows,
vec![vec![
turso_core::Value::Float((1.0 + 2.0 + 3.0 + 4.0) / (4.0)),
turso_core::Value::Null
]]
);
Ok(())
}