Merge pull request #58 from pereman2/avg

core: Avg aggregation function
This commit is contained in:
Pekka Enberg
2024-06-30 17:30:50 +03:00
committed by GitHub
4 changed files with 356 additions and 39 deletions

View File

@@ -4,12 +4,43 @@ use std::rc::Rc;
use crate::pager::Pager;
use crate::schema::Schema;
use crate::sqlite3_ondisk::{DatabaseHeader, MIN_PAGE_CACHE_SIZE};
use crate::vdbe::{Insn, Program, ProgramBuilder};
use crate::vdbe::{AggFunc, Insn, Program, ProgramBuilder};
use anyhow::Result;
use sqlite3_parser::ast::{
Expr, Literal, OneSelect, PragmaBody, QualifiedName, Select, Stmt, UnaryOperator,
};
enum AggregationFunc {
Avg,
Count,
GroupConcat,
Max,
Min,
StringAgg,
Sum,
Total,
}
struct ColumnInfo {
func: Option<AggregationFunc>,
args: Option<Vec<Expr>>,
columns_to_allocate: usize, /* number of result columns this col will result on */
}
impl ColumnInfo {
pub fn new() -> Self {
Self {
func: None,
args: None,
columns_to_allocate: 1,
}
}
pub fn is_aggregation_function(&self) -> bool {
return self.func.is_some();
}
}
/// Translate SQL statement into bytecode program.
pub fn translate(
schema: &Schema,
@@ -31,7 +62,14 @@ fn translate_select(schema: &Schema, select: Select) -> Result<Program> {
let start_offset = program.offset();
let limit_reg = if let Some(limit) = select.limit {
assert!(limit.offset.is_none());
Some(translate_expr(&mut program, None, None, &limit.expr))
let target_register = program.alloc_register();
Some(translate_expr(
&mut program,
None,
None,
&limit.expr,
target_register,
))
} else {
None
};
@@ -62,18 +100,60 @@ fn translate_select(schema: &Schema, select: Select) -> Result<Program> {
program.emit_insn(Insn::OpenReadAwait);
program.emit_insn(Insn::RewindAsync { cursor_id });
let rewind_await_offset = program.emit_placeholder();
let (register_start, register_end) =
translate_columns(&mut program, Some(cursor_id), Some(table), columns);
program.emit_insn(Insn::ResultRow {
register_start,
register_end,
});
let info_per_columns = analyze_columns(&columns, Some(table));
let exist_aggregation = info_per_columns.iter().any(|info| info.func.is_some());
let (register_start, register_end) = translate_columns(
&mut program,
Some(cursor_id),
Some(table),
&columns,
&info_per_columns,
exist_aggregation,
);
if exist_aggregation {
// Only one ResultRow will occurr with aggregations.
program.emit_insn(Insn::NextAsync { cursor_id });
program.emit_insn(Insn::NextAwait {
cursor_id,
pc_if_next: rewind_await_offset,
});
let mut target = register_start;
for info in &info_per_columns {
if info.is_aggregation_function() {
let func = match info.func.as_ref().unwrap() {
AggregationFunc::Avg => AggFunc::Avg,
AggregationFunc::Count => todo!(),
AggregationFunc::GroupConcat => todo!(),
AggregationFunc::Max => todo!(),
AggregationFunc::Min => todo!(),
AggregationFunc::StringAgg => todo!(),
AggregationFunc::Sum => todo!(),
AggregationFunc::Total => todo!(),
};
program.emit_insn(Insn::AggFinal {
register: target,
func,
});
}
target += info.columns_to_allocate;
}
// only one result row
program.emit_insn(Insn::ResultRow {
register_start,
register_end,
});
} else {
program.emit_insn(Insn::ResultRow {
register_start,
register_end,
});
program.emit_insn(Insn::NextAsync { cursor_id });
program.emit_insn(Insn::NextAwait {
cursor_id,
pc_if_next: rewind_await_offset,
});
}
let limit_decr_insn = limit_reg.map(|_| program.emit_placeholder());
program.emit_insn(Insn::NextAsync { cursor_id });
program.emit_insn(Insn::NextAwait {
cursor_id,
pc_if_next: rewind_await_offset,
});
program.fixup_insn(
rewind_await_offset,
Insn::RewindAwait {
@@ -88,8 +168,17 @@ fn translate_select(schema: &Schema, select: Select) -> Result<Program> {
from: None,
..
} => {
let (register_start, register_end) =
translate_columns(&mut program, None, None, columns);
let info_per_columns = analyze_columns(&columns, None);
let exist_aggregation = info_per_columns.iter().any(|info| info.func.is_some());
assert!(!exist_aggregation);
let (register_start, register_end) = translate_columns(
&mut program,
None,
None,
&columns,
&info_per_columns,
exist_aggregation,
);
program.emit_insn(Insn::ResultRow {
register_start,
register_end,
@@ -126,13 +215,33 @@ fn translate_columns(
program: &mut ProgramBuilder,
cursor_id: Option<usize>,
table: Option<&crate::schema::Table>,
columns: Vec<sqlite3_parser::ast::ResultColumn>,
columns: &Vec<sqlite3_parser::ast::ResultColumn>,
info_per_columns: &Vec<ColumnInfo>,
exist_aggregation: bool,
) -> (usize, usize) {
let register_start = program.next_free_register();
for col in columns {
translate_column(program, cursor_id, table, col);
}
// allocate one register as output for each col
let registers: usize = info_per_columns
.iter()
.map(|col| col.columns_to_allocate)
.sum();
program.alloc_registers(registers);
let register_end = program.next_free_register();
let mut target = register_start;
for (col, info) in columns.iter().zip(info_per_columns) {
translate_column(
program,
cursor_id,
table,
col,
info,
exist_aggregation,
target,
);
target += info.columns_to_allocate;
}
(register_start, register_end)
}
@@ -140,24 +249,36 @@ fn translate_column(
program: &mut ProgramBuilder,
cursor_id: Option<usize>,
table: Option<&crate::schema::Table>,
col: sqlite3_parser::ast::ResultColumn,
col: &sqlite3_parser::ast::ResultColumn,
info: &ColumnInfo,
exist_aggregation: bool, // notify this column there is aggregation going on in other columns (or this one)
target_register: usize, // where to store the result, in case of star it will be the start of registers added
) {
if exist_aggregation && !info.is_aggregation_function() {
// FIXME: let's do nothing
return;
}
match col {
sqlite3_parser::ast::ResultColumn::Expr(expr, _) => {
let _ = translate_expr(program, cursor_id, table, &expr);
if info.is_aggregation_function() {
let _ =
translate_aggregation(program, cursor_id, table, &expr, info, target_register);
} else {
let _ = translate_expr(program, cursor_id, table, &expr, target_register);
}
}
sqlite3_parser::ast::ResultColumn::Star => {
for (i, col) in table.unwrap().columns.iter().enumerate() {
let dest = program.alloc_register();
if col.is_rowid_alias() {
program.emit_insn(Insn::RowId {
cursor_id: cursor_id.unwrap(),
dest,
dest: target_register + i,
});
} else {
program.emit_insn(Insn::Column {
column: i,
dest,
dest: target_register + i,
cursor_id: cursor_id.unwrap(),
});
}
@@ -167,11 +288,72 @@ fn translate_column(
}
}
fn analyze_columns(
columns: &Vec<sqlite3_parser::ast::ResultColumn>,
table: Option<&crate::schema::Table>,
) -> Vec<ColumnInfo> {
let mut column_information_list = Vec::new();
column_information_list.reserve(columns.len());
for column in columns {
let mut info = ColumnInfo::new();
info.columns_to_allocate = 1;
if let sqlite3_parser::ast::ResultColumn::Star = column {
info.columns_to_allocate = table.unwrap().columns.len();
} else {
analyze_column(column, &mut info);
}
column_information_list.push(info);
}
column_information_list
}
/*
Walk column expression trying to find aggregation functions. If it finds one it will save information
about it.
*/
fn analyze_column(column: &sqlite3_parser::ast::ResultColumn, column_info_out: &mut ColumnInfo) {
match column {
sqlite3_parser::ast::ResultColumn::Expr(expr, _) => match expr {
Expr::FunctionCall {
name,
distinctness: _,
args,
filter_over: _,
} => {
let func_type = match name.0.as_str() {
"avg" => Some(AggregationFunc::Avg),
"count" => Some(AggregationFunc::Count),
"group_concat" => Some(AggregationFunc::GroupConcat),
"max" => Some(AggregationFunc::Max),
"min" => Some(AggregationFunc::Min),
"string_agg" => Some(AggregationFunc::StringAgg),
"sum" => Some(AggregationFunc::Sum),
"total" => Some(AggregationFunc::Total),
_ => None,
};
if func_type.is_none() {
analyze_column(column, column_info_out);
} else {
column_info_out.func = func_type;
// TODO(pere): use lifetimes for args? Arenas would be lovely here :(
column_info_out.args = args.clone();
}
}
Expr::FunctionCallStar { .. } => todo!(),
_ => {}
},
sqlite3_parser::ast::ResultColumn::Star => {}
sqlite3_parser::ast::ResultColumn::TableStar(_) => {}
}
}
fn translate_expr(
program: &mut ProgramBuilder,
cursor_id: Option<usize>,
table: Option<&crate::schema::Table>,
expr: &Expr,
target_register: usize,
) -> usize {
match expr {
Expr::Between { .. } => todo!(),
@@ -185,20 +367,19 @@ fn translate_expr(
Expr::FunctionCallStar { .. } => todo!(),
Expr::Id(ident) => {
let (idx, col) = table.unwrap().get_column(&ident.0).unwrap();
let dest = program.alloc_register();
if col.primary_key {
program.emit_insn(Insn::RowId {
cursor_id: cursor_id.unwrap(),
dest,
dest: target_register,
});
} else {
program.emit_insn(Insn::Column {
column: idx,
dest,
dest: target_register,
cursor_id: cursor_id.unwrap(),
});
}
dest
target_register
}
Expr::InList { .. } => todo!(),
Expr::InSelect { .. } => todo!(),
@@ -207,20 +388,18 @@ fn translate_expr(
Expr::Like { .. } => todo!(),
Expr::Literal(lit) => match lit {
Literal::Numeric(val) => {
let dest = program.alloc_register();
program.emit_insn(Insn::Integer {
value: val.parse().unwrap(),
dest,
dest: target_register,
});
dest
target_register
}
Literal::String(s) => {
let dest = program.alloc_register();
program.emit_insn(Insn::String8 {
value: s[1..s.len() - 1].to_string(),
dest,
dest: target_register,
});
dest
target_register
}
Literal::Blob(_) => todo!(),
Literal::Keyword(_) => todo!(),
@@ -240,6 +419,44 @@ fn translate_expr(
}
}
fn translate_aggregation(
program: &mut ProgramBuilder,
cursor_id: Option<usize>,
table: Option<&crate::schema::Table>,
expr: &Expr,
info: &ColumnInfo,
target_register: usize,
) -> Result<usize> {
let _ = expr;
assert!(info.func.is_some());
let func = info.func.as_ref().unwrap();
let args = info.args.as_ref().unwrap();
let dest = match func {
AggregationFunc::Avg => {
if args.len() != 1 {
anyhow::bail!("Parse error: avg bad number of arguments");
}
let expr = &args[0];
let expr_reg = program.alloc_register();
let _ = translate_expr(program, cursor_id, table, &expr, expr_reg);
program.emit_insn(Insn::AggStep {
acc_reg: target_register,
col: expr_reg,
func: crate::vdbe::AggFunc::Avg,
});
target_register
}
AggregationFunc::Count => todo!(),
AggregationFunc::GroupConcat => todo!(),
AggregationFunc::Max => todo!(),
AggregationFunc::Min => todo!(),
AggregationFunc::StringAgg => todo!(),
AggregationFunc::Sum => todo!(),
AggregationFunc::Total => todo!(),
};
Ok(dest)
}
fn translate_pragma(
name: &QualifiedName,
body: Option<PragmaBody>,

View File

@@ -11,6 +11,11 @@ pub enum Value<'a> {
Blob(&'a Vec<u8>),
}
#[derive(Debug, Clone, PartialEq)]
pub enum AggContext {
Avg(f64, usize), // acc and count
}
#[derive(Debug, Clone, PartialEq)]
pub enum OwnedValue {
Null,
@@ -18,6 +23,7 @@ pub enum OwnedValue {
Float(f64),
Text(Rc<String>),
Blob(Rc<Vec<u8>>),
Agg(Box<AggContext>),
}
pub fn to_value(value: &OwnedValue) -> Value<'_> {
@@ -27,6 +33,10 @@ pub fn to_value(value: &OwnedValue) -> Value<'_> {
OwnedValue::Float(f) => Value::Float(*f),
OwnedValue::Text(s) => Value::Text(s),
OwnedValue::Blob(b) => Value::Blob(b),
OwnedValue::Agg(a) => match a.as_ref() {
AggContext::Avg(acc, _count) => Value::Float(*acc), // we assume aggfinal was called
_ => todo!(),
},
}
}

View File

@@ -1,9 +1,10 @@
use crate::btree::BTreeCursor;
use crate::pager::Pager;
use crate::types::{Cursor, CursorResult, OwnedValue, Record};
use crate::types::{AggContext, Cursor, CursorResult, OwnedValue, Record};
use anyhow::Result;
use core::fmt;
use std::borrow::BorrowMut;
use std::cell::RefCell;
use std::collections::BTreeMap;
use std::rc::Rc;
@@ -98,6 +99,30 @@ pub enum Insn {
reg: usize,
target_pc: BranchOffset,
},
AggStep {
acc_reg: usize,
col: usize,
func: AggFunc,
},
AggFinal {
register: usize,
func: AggFunc,
},
}
pub enum AggFunc {
Avg,
}
impl AggFunc {
fn to_string(&self) -> &str {
match self {
AggFunc::Avg => "avg",
_ => "unknown",
}
}
}
pub struct ProgramBuilder {
@@ -121,6 +146,12 @@ impl ProgramBuilder {
reg
}
pub fn alloc_registers(&mut self, amount: usize) -> usize {
let reg = self.next_free_register;
self.next_free_register += amount;
reg
}
pub fn next_free_register(&self) -> usize {
self.next_free_register
}
@@ -334,6 +365,42 @@ impl Program {
}
_ => unreachable!("DecrJumpZero on non-integer register"),
},
Insn::AggStep { acc_reg, col, func } => {
if let OwnedValue::Null = &state.registers[*acc_reg] {
state.registers[*acc_reg] =
OwnedValue::Agg(Box::new(AggContext::Avg(0.0, 0)));
}
match func {
AggFunc::Avg => {
let col = state.registers[*col].clone();
let OwnedValue::Agg(agg) = state.registers[*acc_reg].borrow_mut()
else {
unreachable!();
};
let AggContext::Avg(acc, count) = agg.borrow_mut();
match col {
OwnedValue::Integer(i) => *acc += i as f64,
OwnedValue::Float(f) => *acc += f,
_ => unreachable!(),
}
*count += 1;
}
};
state.pc += 1;
}
Insn::AggFinal { register, func } => {
match func {
AggFunc::Avg => {
let OwnedValue::Agg(agg) = state.registers[*register].borrow_mut()
else {
unreachable!();
};
let AggContext::Avg(acc, count) = agg.borrow_mut();
*acc /= *count as f64
}
};
state.pc += 1;
}
}
}
}
@@ -547,6 +614,24 @@ fn insn_to_str(addr: usize, insn: &Insn) -> String {
IntValue::Usize(0),
"".to_string(),
),
Insn::AggStep { func, acc_reg, col } => (
"AggStep",
IntValue::Usize(0),
IntValue::Usize(*col),
IntValue::Usize(*acc_reg),
func.to_string(),
IntValue::Usize(0),
format!("accum=r[{}] step({})", *acc_reg, *col),
),
Insn::AggFinal { register, func } => (
"AggFinal",
IntValue::Usize(0),
IntValue::Usize(*register),
IntValue::Usize(0),
func.to_string(),
IntValue::Usize(0),
format!("accum=r[{}]", *register),
),
};
format!(
"{:<4} {:<13} {:<4} {:<4} {:<4} {:<13} {:<2} {}",

View File

@@ -2,6 +2,7 @@
import sqlite3
from faker import Faker
import random
conn = sqlite3.connect('database.db')
cursor = conn.cursor()
@@ -17,7 +18,8 @@ cursor.execute('''
address TEXT,
city TEXT,
state TEXT,
zipcode TEXT
zipcode TEXT,
age INTEGER
)
''')
@@ -31,11 +33,14 @@ for _ in range(10000):
city = fake.city()
state = fake.state_abbr()
zipcode = fake.zipcode()
age = random.randint(0, 100) % 99
cursor.execute('''
INSERT INTO users (first_name, last_name, email, phone_number, address, city, state, zipcode)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
''', (first_name, last_name, email, phone_number, address, city, state, zipcode))
INSERT INTO users (first_name, last_name, email, phone_number, address, city, state, zipcode, age)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
''', (first_name, last_name, email, phone_number, address, city, state, zipcode, age))
conn.commit()
conn.close()