From 52f3216211c6c8bc24e3fa26ac52258da624ada3 Mon Sep 17 00:00:00 2001 From: Nikita Sivukhin Date: Fri, 26 Sep 2025 17:11:06 +0400 Subject: [PATCH] fix avg aggregation - ignore NULL rows as SQLite does - emit NULL instead of NaN when no rows were aggregated - adjust agg column alias name --- .../packages/native/promise.test.ts | 31 +++++++++++++++++ core/translate/select.rs | 14 +++++--- core/vdbe/execute.rs | 32 ++++++++++------- .../query_processing/test_read_path.rs | 34 +++++++++++++++++++ 4 files changed, 95 insertions(+), 16 deletions(-) diff --git a/bindings/javascript/packages/native/promise.test.ts b/bindings/javascript/packages/native/promise.test.ts index 82d9e1064..422a6f56a 100644 --- a/bindings/javascript/packages/native/promise.test.ts +++ b/bindings/javascript/packages/native/promise.test.ts @@ -72,6 +72,37 @@ test('explicit connect', async () => { expect(await db.prepare("SELECT 1 as x").all()).toEqual([{ x: 1 }]); }) +test('avg-bug', async () => { + const db = await connect(':memory:'); + const create = db.prepare(`create table "aggregate_table" ( + "id" integer primary key autoincrement not null, + "name" text not null, + "a" integer, + "b" integer, + "c" integer, + "null_only" integer + );`); + + await create.run(); + const insert = db.prepare( + `insert into "aggregate_table" ("id", "name", "a", "b", "c", "null_only") values (null, ?, ?, ?, ?, null), (null, ?, ?, ?, ?, null), (null, ?, ?, ?, ?, null), (null, ?, ?, ?, ?, null), (null, ?, ?, ?, ?, null), (null, ?, ?, ?, ?, null), (null, ?, ?, ?, ?, null);`, + ); + + await insert.run( + 'value 1', 5, 10, 20, + 'value 1', 5, 20, 30, + 'value 2', 10, 50, 60, + 'value 3', 20, 20, null, + 'value 4', null, 90, 120, + 'value 5', 80, 10, null, + 'value 6', null, null, 150, + ); + + expect(await db.prepare(`select avg("a") from "aggregate_table";`).get()).toEqual({ 'avg ("a")': 24 }); + expect(await db.prepare(`select avg("null_only") from "aggregate_table";`).get()).toEqual({ 'avg ("null_only")': null }); 
+ expect(await db.prepare(`select avg(distinct "b") from "aggregate_table";`).get()).toEqual({ 'avg (DISTINCT "b")': 42.5 }); +}) + test('on-disk db', async () => { const path = `test-${(Math.random() * 10000) | 0}.db`; try { diff --git a/core/translate/select.rs b/core/translate/select.rs index 3b305ba12..bd6d1a2b8 100644 --- a/core/translate/select.rs +++ b/core/translate/select.rs @@ -369,6 +369,15 @@ fn prepare_one_select_plan( } } ResultColumn::Expr(ref mut expr, maybe_alias) => { + let alias = if let Some(alias) = maybe_alias { + match alias { + ast::As::Elided(alias) => alias.as_str().to_string(), + ast::As::As(alias) => alias.as_str().to_string(), + } + } else { + // we always emit alias - otherwise user will see very confusing column name (e.g. avg(t0.c1)) + expr.as_ref().to_string() + }; bind_and_rewrite_expr( expr, Some(&mut plan.table_references), @@ -385,10 +394,7 @@ fn prepare_one_select_plan( Some(&mut windows), )?; plan.result_columns.push(ResultSetColumn { - alias: maybe_alias.as_ref().map(|alias| match alias { - ast::As::Elided(alias) => alias.as_str().to_string(), - ast::As::As(alias) => alias.as_str().to_string(), - }), + alias: Some(alias), expr: expr.as_ref().clone(), contains_aggregates, }); diff --git a/core/vdbe/execute.rs b/core/vdbe/execute.rs index 4b1bdd4d1..fc55eee0e 100644 --- a/core/vdbe/execute.rs +++ b/core/vdbe/execute.rs @@ -3637,17 +3637,21 @@ pub fn op_agg_step( match func { AggFunc::Avg => { let col = state.registers[*col].clone(); - let Register::Aggregate(agg) = state.registers[*acc_reg].borrow_mut() else { - panic!( - "Unexpected value {:?} in AggStep at register {}", - state.registers[*acc_reg], *acc_reg - ); - }; - let AggContext::Avg(acc, count) = agg.borrow_mut() else { - unreachable!(); - }; - *acc = acc.exec_add(col.get_value()); - *count += 1; + // > The avg() function returns the average value of all non-NULL X within a group + // https://sqlite.org/lang_aggfunc.html#avg + if !col.is_null() { + let 
Register::Aggregate(agg) = state.registers[*acc_reg].borrow_mut() else { + panic!( + "Unexpected value {:?} in AggStep at register {}", + state.registers[*acc_reg], *acc_reg + ); + }; + let AggContext::Avg(acc, count) = agg.borrow_mut() else { + unreachable!(); + }; + *acc = acc.exec_add(col.get_value()); + *count += 1; + } } AggFunc::Sum | AggFunc::Total => { let col = state.registers[*col].clone(); @@ -3915,7 +3919,11 @@ pub fn op_agg_final( let AggContext::Avg(acc, count) = agg else { unreachable!(); }; - let acc = acc.clone() / count.clone(); + let acc = if count.as_int() == Some(0) { + Value::Null + } else { + acc.clone() / count.clone() + }; state.registers[dest_reg] = Register::Value(acc); } AggFunc::Sum => { diff --git a/tests/integration/query_processing/test_read_path.rs b/tests/integration/query_processing/test_read_path.rs index 2285b33f8..3a1b3e52a 100644 --- a/tests/integration/query_processing/test_read_path.rs +++ b/tests/integration/query_processing/test_read_path.rs @@ -750,3 +750,37 @@ fn test_cte_alias() -> anyhow::Result<()> { } Ok(()) } + +#[test] +fn test_avg_agg() -> anyhow::Result<()> { + let tmp_db = TempDatabase::new_with_rusqlite("create table t (x, y);", false); + let conn = tmp_db.connect_limbo(); + conn.execute("insert into t values (1, null), (2, null), (3, null), (null, null), (4, null)")?; + let mut rows = Vec::new(); + let mut stmt = conn.prepare("select avg(x), avg(y) from t")?; + loop { + match stmt.step()? { + StepResult::Row => { + let row = stmt.row().unwrap(); + rows.push(row.get_values().cloned().collect::<Vec<_>>()); + } + StepResult::Done => break, + StepResult::IO => stmt.run_once()?, + _ => panic!("Unexpected step result"), + } + } + + assert_eq!(stmt.num_columns(), 2); + assert_eq!(stmt.get_column_name(0), "avg (x)"); + assert_eq!(stmt.get_column_name(1), "avg (y)"); + + assert_eq!( + rows, + vec![vec![ + turso_core::Value::Float((1.0 + 2.0 + 3.0 + 4.0) / (4.0)), + turso_core::Value::Null + ]] + ); + + Ok(()) +}