diff --git a/bindings/javascript/packages/native/promise.test.ts b/bindings/javascript/packages/native/promise.test.ts index 82d9e1064..422a6f56a 100644 --- a/bindings/javascript/packages/native/promise.test.ts +++ b/bindings/javascript/packages/native/promise.test.ts @@ -72,6 +72,37 @@ test('explicit connect', async () => { expect(await db.prepare("SELECT 1 as x").all()).toEqual([{ x: 1 }]); }) +test('avg-bug', async () => { + const db = await connect(':memory:'); + const create = db.prepare(`create table "aggregate_table" ( + "id" integer primary key autoincrement not null, + "name" text not null, + "a" integer, + "b" integer, + "c" integer, + "null_only" integer + );`); + + await create.run(); + const insert = db.prepare( + `insert into "aggregate_table" ("id", "name", "a", "b", "c", "null_only") values (null, ?, ?, ?, ?, null), (null, ?, ?, ?, ?, null), (null, ?, ?, ?, ?, null), (null, ?, ?, ?, ?, null), (null, ?, ?, ?, ?, null), (null, ?, ?, ?, ?, null), (null, ?, ?, ?, ?, null);`, + ); + + await insert.run( + 'value 1', 5, 10, 20, + 'value 1', 5, 20, 30, + 'value 2', 10, 50, 60, + 'value 3', 20, 20, null, + 'value 4', null, 90, 120, + 'value 5', 80, 10, null, + 'value 6', null, null, 150, + ); + + expect(await db.prepare(`select avg("a") from "aggregate_table";`).get()).toEqual({ 'avg ("a")': 24 }); + expect(await db.prepare(`select avg("null_only") from "aggregate_table";`).get()).toEqual({ 'avg ("null_only")': null }); + expect(await db.prepare(`select avg(distinct "b") from "aggregate_table";`).get()).toEqual({ 'avg (DISTINCT "b")': 42.5 }); +}) + test('on-disk db', async () => { const path = `test-${(Math.random() * 10000) | 0}.db`; try { diff --git a/core/translate/select.rs b/core/translate/select.rs index 3b305ba12..bd6d1a2b8 100644 --- a/core/translate/select.rs +++ b/core/translate/select.rs @@ -369,6 +369,15 @@ fn prepare_one_select_plan( } } ResultColumn::Expr(ref mut expr, maybe_alias) => { + let alias = if let Some(alias) = maybe_alias { + match alias { + ast::As::Elided(alias) => alias.as_str().to_string(), + ast::As::As(alias) => alias.as_str().to_string(), + } + } else { + // we always emit alias - otherwise user will see very confusing column name (e.g. avg(t0.c1)) + expr.as_ref().to_string() + }; bind_and_rewrite_expr( expr, Some(&mut plan.table_references), @@ -385,10 +394,7 @@ fn prepare_one_select_plan( Some(&mut windows), )?; plan.result_columns.push(ResultSetColumn { - alias: maybe_alias.as_ref().map(|alias| match alias { - ast::As::Elided(alias) => alias.as_str().to_string(), - ast::As::As(alias) => alias.as_str().to_string(), - }), + alias: Some(alias), expr: expr.as_ref().clone(), contains_aggregates, }); diff --git a/core/vdbe/execute.rs b/core/vdbe/execute.rs index 4b1bdd4d1..fc55eee0e 100644 --- a/core/vdbe/execute.rs +++ b/core/vdbe/execute.rs @@ -3637,17 +3637,21 @@ pub fn op_agg_step( match func { AggFunc::Avg => { let col = state.registers[*col].clone(); - let Register::Aggregate(agg) = state.registers[*acc_reg].borrow_mut() else { - panic!( - "Unexpected value {:?} in AggStep at register {}", - state.registers[*acc_reg], *acc_reg - ); - }; - let AggContext::Avg(acc, count) = agg.borrow_mut() else { - unreachable!(); - }; - *acc = acc.exec_add(col.get_value()); - *count += 1; + // > The avg() function returns the average value of all non-NULL X within a group + // https://sqlite.org/lang_aggfunc.html#avg + if !col.is_null() { + let Register::Aggregate(agg) = state.registers[*acc_reg].borrow_mut() else { + panic!( + "Unexpected value {:?} in AggStep at register {}", + state.registers[*acc_reg], *acc_reg + ); + }; + let AggContext::Avg(acc, count) = agg.borrow_mut() else { + unreachable!(); + }; + *acc = acc.exec_add(col.get_value()); + *count += 1; + } } AggFunc::Sum | AggFunc::Total => { let col = state.registers[*col].clone(); @@ -3915,7 +3919,11 @@ pub fn op_agg_final( let AggContext::Avg(acc, count) = agg else { unreachable!(); }; - let acc = acc.clone() / count.clone(); + let acc = if count.as_int() == Some(0) { + Value::Null + } else { + acc.clone() / count.clone() + }; state.registers[dest_reg] = Register::Value(acc); } AggFunc::Sum => { diff --git a/tests/integration/query_processing/test_read_path.rs b/tests/integration/query_processing/test_read_path.rs index 2285b33f8..3a1b3e52a 100644 --- a/tests/integration/query_processing/test_read_path.rs +++ b/tests/integration/query_processing/test_read_path.rs @@ -750,3 +750,37 @@ fn test_cte_alias() -> anyhow::Result<()> { } Ok(()) } + +#[test] +fn test_avg_agg() -> anyhow::Result<()> { + let tmp_db = TempDatabase::new_with_rusqlite("create table t (x, y);", false); + let conn = tmp_db.connect_limbo(); + conn.execute("insert into t values (1, null), (2, null), (3, null), (null, null), (4, null)")?; + let mut rows = Vec::new(); + let mut stmt = conn.prepare("select avg(x), avg(y) from t")?; + loop { + match stmt.step()? { + StepResult::Row => { + let row = stmt.row().unwrap(); + rows.push(row.get_values().cloned().collect::>()); + } + StepResult::Done => break, + StepResult::IO => stmt.run_once()?, + _ => panic!("Unexpected step result"), + } + } + + assert_eq!(stmt.num_columns(), 2); + assert_eq!(stmt.get_column_name(0), "avg (x)"); + assert_eq!(stmt.get_column_name(1), "avg (y)"); + + assert_eq!( + rows, + vec![vec![ + turso_core::Value::Float((1.0 + 2.0 + 3.0 + 4.0) / (4.0)), + turso_core::Value::Null + ]] + ); + + Ok(()) +}