From 52f3216211c6c8bc24e3fa26ac52258da624ada3 Mon Sep 17 00:00:00 2001 From: Nikita Sivukhin Date: Fri, 26 Sep 2025 17:11:06 +0400 Subject: [PATCH 1/3] fix avg aggregation - ignore NULL rows as SQLite do - emit NULL instead of NaN when no rows were aggregated - adjust agg column alias name --- .../packages/native/promise.test.ts | 31 +++++++++++++++++ core/translate/select.rs | 14 +++++--- core/vdbe/execute.rs | 32 ++++++++++------- .../query_processing/test_read_path.rs | 34 +++++++++++++++++++ 4 files changed, 95 insertions(+), 16 deletions(-) diff --git a/bindings/javascript/packages/native/promise.test.ts b/bindings/javascript/packages/native/promise.test.ts index 82d9e1064..422a6f56a 100644 --- a/bindings/javascript/packages/native/promise.test.ts +++ b/bindings/javascript/packages/native/promise.test.ts @@ -72,6 +72,37 @@ test('explicit connect', async () => { expect(await db.prepare("SELECT 1 as x").all()).toEqual([{ x: 1 }]); }) +test('avg-bug', async () => { + const db = await connect(':memory:'); + const create = db.prepare(`create table "aggregate_table" ( + "id" integer primary key autoincrement not null, + "name" text not null, + "a" integer, + "b" integer, + "c" integer, + "null_only" integer + );`); + + await create.run(); + const insert = db.prepare( + `insert into "aggregate_table" ("id", "name", "a", "b", "c", "null_only") values (null, ?, ?, ?, ?, null), (null, ?, ?, ?, ?, null), (null, ?, ?, ?, ?, null), (null, ?, ?, ?, ?, null), (null, ?, ?, ?, ?, null), (null, ?, ?, ?, ?, null), (null, ?, ?, ?, ?, null);`, + ); + + await insert.run( + 'value 1', 5, 10, 20, + 'value 1', 5, 20, 30, + 'value 2', 10, 50, 60, + 'value 3', 20, 20, null, + 'value 4', null, 90, 120, + 'value 5', 80, 10, null, + 'value 6', null, null, 150, + ); + + expect(await db.prepare(`select avg("a") from "aggregate_table";`).get()).toEqual({ 'avg ("a")': 24 }); + expect(await db.prepare(`select avg("null_only") from "aggregate_table";`).get()).toEqual({ 'avg ("null_only")': null }); + expect(await db.prepare(`select avg(distinct "b") from "aggregate_table";`).get()).toEqual({ 'avg (DISTINCT "b")': 42.5 }); +}) + test('on-disk db', async () => { const path = `test-${(Math.random() * 10000) | 0}.db`; try { diff --git a/core/translate/select.rs b/core/translate/select.rs index 3b305ba12..bd6d1a2b8 100644 --- a/core/translate/select.rs +++ b/core/translate/select.rs @@ -369,6 +369,15 @@ fn prepare_one_select_plan( } } ResultColumn::Expr(ref mut expr, maybe_alias) => { + let alias = if let Some(alias) = maybe_alias { + match alias { + ast::As::Elided(alias) => alias.as_str().to_string(), + ast::As::As(alias) => alias.as_str().to_string(), + } + } else { + // we always emit alias - otherwise user will see very confusing column name (e.g. avg(t0.c1)) + expr.as_ref().to_string() + }; bind_and_rewrite_expr( expr, Some(&mut plan.table_references), @@ -385,10 +394,7 @@ fn prepare_one_select_plan( Some(&mut windows), )?; plan.result_columns.push(ResultSetColumn { - alias: maybe_alias.as_ref().map(|alias| match alias { - ast::As::Elided(alias) => alias.as_str().to_string(), - ast::As::As(alias) => alias.as_str().to_string(), - }), + alias: Some(alias), expr: expr.as_ref().clone(), contains_aggregates, }); diff --git a/core/vdbe/execute.rs b/core/vdbe/execute.rs index 4b1bdd4d1..fc55eee0e 100644 --- a/core/vdbe/execute.rs +++ b/core/vdbe/execute.rs @@ -3637,17 +3637,21 @@ pub fn op_agg_step( match func { AggFunc::Avg => { let col = state.registers[*col].clone(); - let Register::Aggregate(agg) = state.registers[*acc_reg].borrow_mut() else { - panic!( - "Unexpected value {:?} in AggStep at register {}", - state.registers[*acc_reg], *acc_reg - ); - }; - let AggContext::Avg(acc, count) = agg.borrow_mut() else { - unreachable!(); - }; - *acc = acc.exec_add(col.get_value()); - *count += 1; + // > The avg() function returns the average value of all non-NULL X within a group + // https://sqlite.org/lang_aggfunc.html#avg + if !col.is_null() { + let Register::Aggregate(agg) = state.registers[*acc_reg].borrow_mut() else { + panic!( + "Unexpected value {:?} in AggStep at register {}", + state.registers[*acc_reg], *acc_reg + ); + }; + let AggContext::Avg(acc, count) = agg.borrow_mut() else { + unreachable!(); + }; + *acc = acc.exec_add(col.get_value()); + *count += 1; + } } AggFunc::Sum | AggFunc::Total => { let col = state.registers[*col].clone(); @@ -3915,7 +3919,11 @@ pub fn op_agg_final( let AggContext::Avg(acc, count) = agg else { unreachable!(); }; - let acc = acc.clone() / count.clone(); + let acc = if count.as_int() == Some(0) { + Value::Null + } else { + acc.clone() / count.clone() + }; state.registers[dest_reg] = Register::Value(acc); } AggFunc::Sum => { diff --git a/tests/integration/query_processing/test_read_path.rs b/tests/integration/query_processing/test_read_path.rs index 2285b33f8..3a1b3e52a 100644 --- a/tests/integration/query_processing/test_read_path.rs +++ b/tests/integration/query_processing/test_read_path.rs @@ -750,3 +750,37 @@ fn test_cte_alias() -> anyhow::Result<()> { } Ok(()) } + +#[test] +fn test_avg_agg() -> anyhow::Result<()> { + let tmp_db = TempDatabase::new_with_rusqlite("create table t (x, y);", false); + let conn = tmp_db.connect_limbo(); + conn.execute("insert into t values (1, null), (2, null), (3, null), (null, null), (4, null)")?; + let mut rows = Vec::new(); + let mut stmt = conn.prepare("select avg(x), avg(y) from t")?; + loop { + match stmt.step()? { + StepResult::Row => { + let row = stmt.row().unwrap(); + rows.push(row.get_values().cloned().collect::>()); + } + StepResult::Done => break, + StepResult::IO => stmt.run_once()?, + _ => panic!("Unexpected step result"), + } + } + + assert_eq!(stmt.num_columns(), 2); + assert_eq!(stmt.get_column_name(0), "avg (x)"); + assert_eq!(stmt.get_column_name(1), "avg (y)"); + + assert_eq!( + rows, + vec![vec![ + turso_core::Value::Float((1.0 + 2.0 + 3.0 + 4.0) / (4.0)), + turso_core::Value::Null + ]] + ); + + Ok(()) +} From 5b5379d0788e93e5564fe968b4a5d9cc18233a3a Mon Sep 17 00:00:00 2001 From: Nikita Sivukhin Date: Fri, 26 Sep 2025 17:40:41 +0400 Subject: [PATCH 2/3] propagate context to stringifier to properly derive column names --- core/lib.rs | 8 +++++++- core/translate/select.rs | 14 ++++---------- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/core/lib.rs b/core/lib.rs index b3e1c6a0d..8784f8aee 100644 --- a/core/lib.rs +++ b/core/lib.rs @@ -40,6 +40,7 @@ pub mod numeric; mod numeric; use crate::storage::checksum::CHECKSUM_REQUIRED_RESERVED_BYTES; +use crate::translate::display::PlanContext; use crate::translate::pragma::TURSO_CDC_DEFAULT_TABLE_NAME; #[cfg(all(feature = "fs", feature = "conn_raw_api"))] use crate::types::{WalFrameInfo, WalState}; @@ -91,6 +92,7 @@ pub use storage::{ }; use tracing::{instrument, Level}; use turso_macros::match_ignore_ascii_case; +use turso_parser::ast::fmt::ToTokens; use turso_parser::{ast, ast::Cmd, parser::Parser}; use types::IOResult; pub use types::RefValue; @@ -2562,7 +2564,11 @@ impl Statement { let column = &self.program.result_columns.get(idx).expect("No column"); match column.name(&self.program.table_references) { Some(name) => Cow::Borrowed(name), - None => Cow::Owned(column.expr.to_string()), + None => { + let tables = [&self.program.table_references]; + let ctx = PlanContext(&tables); + Cow::Owned(column.expr.displayer(&ctx).to_string()) + } } } QueryMode::Explain => Cow::Borrowed(EXPLAIN_COLUMNS[idx]), diff --git a/core/translate/select.rs b/core/translate/select.rs index bd6d1a2b8..3b305ba12 100644 --- a/core/translate/select.rs +++ b/core/translate/select.rs @@ -369,15 +369,6 @@ fn prepare_one_select_plan( } } ResultColumn::Expr(ref mut expr, maybe_alias) => { - let alias = if let Some(alias) = maybe_alias { - match alias { - ast::As::Elided(alias) => alias.as_str().to_string(), - ast::As::As(alias) => alias.as_str().to_string(), - } - } else { - // we always emit alias - otherwise user will see very confusing column name (e.g. avg(t0.c1)) - expr.as_ref().to_string() - }; bind_and_rewrite_expr( expr, Some(&mut plan.table_references), @@ -394,7 +385,10 @@ fn prepare_one_select_plan( Some(&mut windows), )?; plan.result_columns.push(ResultSetColumn { - alias: Some(alias), + alias: maybe_alias.as_ref().map(|alias| match alias { + ast::As::Elided(alias) => alias.as_str().to_string(), + ast::As::As(alias) => alias.as_str().to_string(), + }), expr: expr.as_ref().clone(), contains_aggregates, }); From a0c47b98b880409a77daf6d428cdd45da7f816bb Mon Sep 17 00:00:00 2001 From: Nikita Sivukhin Date: Fri, 26 Sep 2025 17:41:13 +0400 Subject: [PATCH 3/3] fix test --- bindings/javascript/packages/native/promise.test.ts | 6 +++--- tests/integration/query_processing/test_read_path.rs | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/bindings/javascript/packages/native/promise.test.ts b/bindings/javascript/packages/native/promise.test.ts index 422a6f56a..190fa23c5 100644 --- a/bindings/javascript/packages/native/promise.test.ts +++ b/bindings/javascript/packages/native/promise.test.ts @@ -98,9 +98,9 @@ test('avg-bug', async () => { 'value 6', null, null, 150, ); - expect(await db.prepare(`select avg("a") from "aggregate_table";`).get()).toEqual({ 'avg ("a")': 24 }); - expect(await db.prepare(`select avg("null_only") from "aggregate_table";`).get()).toEqual({ 'avg ("null_only")': null }); - expect(await db.prepare(`select avg(distinct "b") from "aggregate_table";`).get()).toEqual({ 'avg (DISTINCT "b")': 42.5 }); + expect(await db.prepare(`select avg("a") from "aggregate_table";`).get()).toEqual({ 'avg (aggregate_table.a)': 24 }); + expect(await db.prepare(`select avg("null_only") from "aggregate_table";`).get()).toEqual({ 'avg (aggregate_table.null_only)': null }); + expect(await db.prepare(`select avg(distinct "b") from "aggregate_table";`).get()).toEqual({ 'avg (DISTINCT aggregate_table.b)': 42.5 }); }) test('on-disk db', async () => { diff --git a/tests/integration/query_processing/test_read_path.rs b/tests/integration/query_processing/test_read_path.rs index 3a1b3e52a..452ca1c85 100644 --- a/tests/integration/query_processing/test_read_path.rs +++ b/tests/integration/query_processing/test_read_path.rs @@ -771,8 +771,8 @@ fn test_avg_agg() -> anyhow::Result<()> { } assert_eq!(stmt.num_columns(), 2); - assert_eq!(stmt.get_column_name(0), "avg (x)"); - assert_eq!(stmt.get_column_name(1), "avg (y)"); + assert_eq!(stmt.get_column_name(0), "avg (t.x)"); + assert_eq!(stmt.get_column_name(1), "avg (t.y)"); assert_eq!( rows,