mirror of
https://github.com/aljazceru/turso.git
synced 2026-02-23 00:45:37 +01:00
Merge 'fix avg aggregation' from Nikita Sivukhin
- ignore NULL rows as SQLite do - emit NULL instead of NaN when no rows were aggregated - adjust agg column alias name Reviewed-by: Jussi Saurio <jussi.saurio@gmail.com> Closes #3376
This commit is contained in:
@@ -72,6 +72,37 @@ test('explicit connect', async () => {
|
||||
expect(await db.prepare("SELECT 1 as x").all()).toEqual([{ x: 1 }]);
|
||||
})
|
||||
|
||||
test('avg-bug', async () => {
|
||||
const db = await connect(':memory:');
|
||||
const create = db.prepare(`create table "aggregate_table" (
|
||||
"id" integer primary key autoincrement not null,
|
||||
"name" text not null,
|
||||
"a" integer,
|
||||
"b" integer,
|
||||
"c" integer,
|
||||
"null_only" integer
|
||||
);`);
|
||||
|
||||
await create.run();
|
||||
const insert = db.prepare(
|
||||
`insert into "aggregate_table" ("id", "name", "a", "b", "c", "null_only") values (null, ?, ?, ?, ?, null), (null, ?, ?, ?, ?, null), (null, ?, ?, ?, ?, null), (null, ?, ?, ?, ?, null), (null, ?, ?, ?, ?, null), (null, ?, ?, ?, ?, null), (null, ?, ?, ?, ?, null);`,
|
||||
);
|
||||
|
||||
await insert.run(
|
||||
'value 1', 5, 10, 20,
|
||||
'value 1', 5, 20, 30,
|
||||
'value 2', 10, 50, 60,
|
||||
'value 3', 20, 20, null,
|
||||
'value 4', null, 90, 120,
|
||||
'value 5', 80, 10, null,
|
||||
'value 6', null, null, 150,
|
||||
);
|
||||
|
||||
expect(await db.prepare(`select avg("a") from "aggregate_table";`).get()).toEqual({ 'avg (aggregate_table.a)': 24 });
|
||||
expect(await db.prepare(`select avg("null_only") from "aggregate_table";`).get()).toEqual({ 'avg (aggregate_table.null_only)': null });
|
||||
expect(await db.prepare(`select avg(distinct "b") from "aggregate_table";`).get()).toEqual({ 'avg (DISTINCT aggregate_table.b)': 42.5 });
|
||||
})
|
||||
|
||||
test('on-disk db', async () => {
|
||||
const path = `test-${(Math.random() * 10000) | 0}.db`;
|
||||
try {
|
||||
|
||||
@@ -40,6 +40,7 @@ pub mod numeric;
|
||||
mod numeric;
|
||||
|
||||
use crate::storage::checksum::CHECKSUM_REQUIRED_RESERVED_BYTES;
|
||||
use crate::translate::display::PlanContext;
|
||||
use crate::translate::pragma::TURSO_CDC_DEFAULT_TABLE_NAME;
|
||||
#[cfg(all(feature = "fs", feature = "conn_raw_api"))]
|
||||
use crate::types::{WalFrameInfo, WalState};
|
||||
@@ -91,6 +92,7 @@ pub use storage::{
|
||||
};
|
||||
use tracing::{instrument, Level};
|
||||
use turso_macros::match_ignore_ascii_case;
|
||||
use turso_parser::ast::fmt::ToTokens;
|
||||
use turso_parser::{ast, ast::Cmd, parser::Parser};
|
||||
use types::IOResult;
|
||||
pub use types::RefValue;
|
||||
@@ -2562,7 +2564,11 @@ impl Statement {
|
||||
let column = &self.program.result_columns.get(idx).expect("No column");
|
||||
match column.name(&self.program.table_references) {
|
||||
Some(name) => Cow::Borrowed(name),
|
||||
None => Cow::Owned(column.expr.to_string()),
|
||||
None => {
|
||||
let tables = [&self.program.table_references];
|
||||
let ctx = PlanContext(&tables);
|
||||
Cow::Owned(column.expr.displayer(&ctx).to_string())
|
||||
}
|
||||
}
|
||||
}
|
||||
QueryMode::Explain => Cow::Borrowed(EXPLAIN_COLUMNS[idx]),
|
||||
|
||||
@@ -3637,17 +3637,21 @@ pub fn op_agg_step(
|
||||
match func {
|
||||
AggFunc::Avg => {
|
||||
let col = state.registers[*col].clone();
|
||||
let Register::Aggregate(agg) = state.registers[*acc_reg].borrow_mut() else {
|
||||
panic!(
|
||||
"Unexpected value {:?} in AggStep at register {}",
|
||||
state.registers[*acc_reg], *acc_reg
|
||||
);
|
||||
};
|
||||
let AggContext::Avg(acc, count) = agg.borrow_mut() else {
|
||||
unreachable!();
|
||||
};
|
||||
*acc = acc.exec_add(col.get_value());
|
||||
*count += 1;
|
||||
// > The avg() function returns the average value of all non-NULL X within a group
|
||||
// https://sqlite.org/lang_aggfunc.html#avg
|
||||
if !col.is_null() {
|
||||
let Register::Aggregate(agg) = state.registers[*acc_reg].borrow_mut() else {
|
||||
panic!(
|
||||
"Unexpected value {:?} in AggStep at register {}",
|
||||
state.registers[*acc_reg], *acc_reg
|
||||
);
|
||||
};
|
||||
let AggContext::Avg(acc, count) = agg.borrow_mut() else {
|
||||
unreachable!();
|
||||
};
|
||||
*acc = acc.exec_add(col.get_value());
|
||||
*count += 1;
|
||||
}
|
||||
}
|
||||
AggFunc::Sum | AggFunc::Total => {
|
||||
let col = state.registers[*col].clone();
|
||||
@@ -3915,7 +3919,11 @@ pub fn op_agg_final(
|
||||
let AggContext::Avg(acc, count) = agg else {
|
||||
unreachable!();
|
||||
};
|
||||
let acc = acc.clone() / count.clone();
|
||||
let acc = if count.as_int() == Some(0) {
|
||||
Value::Null
|
||||
} else {
|
||||
acc.clone() / count.clone()
|
||||
};
|
||||
state.registers[dest_reg] = Register::Value(acc);
|
||||
}
|
||||
AggFunc::Sum => {
|
||||
|
||||
@@ -750,3 +750,37 @@ fn test_cte_alias() -> anyhow::Result<()> {
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_avg_agg() -> anyhow::Result<()> {
|
||||
let tmp_db = TempDatabase::new_with_rusqlite("create table t (x, y);", false);
|
||||
let conn = tmp_db.connect_limbo();
|
||||
conn.execute("insert into t values (1, null), (2, null), (3, null), (null, null), (4, null)")?;
|
||||
let mut rows = Vec::new();
|
||||
let mut stmt = conn.prepare("select avg(x), avg(y) from t")?;
|
||||
loop {
|
||||
match stmt.step()? {
|
||||
StepResult::Row => {
|
||||
let row = stmt.row().unwrap();
|
||||
rows.push(row.get_values().cloned().collect::<Vec<_>>());
|
||||
}
|
||||
StepResult::Done => break,
|
||||
StepResult::IO => stmt.run_once()?,
|
||||
_ => panic!("Unexpected step result"),
|
||||
}
|
||||
}
|
||||
|
||||
assert_eq!(stmt.num_columns(), 2);
|
||||
assert_eq!(stmt.get_column_name(0), "avg (t.x)");
|
||||
assert_eq!(stmt.get_column_name(1), "avg (t.y)");
|
||||
|
||||
assert_eq!(
|
||||
rows,
|
||||
vec![vec![
|
||||
turso_core::Value::Float((1.0 + 2.0 + 3.0 + 4.0) / (4.0)),
|
||||
turso_core::Value::Null
|
||||
]]
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user