mirror of
https://github.com/aljazceru/turso.git
synced 2026-01-08 10:44:20 +01:00
Merge pull request #58 from pereman2/avg
core: Avg aggregation function
This commit is contained in:
@@ -4,12 +4,43 @@ use std::rc::Rc;
|
||||
use crate::pager::Pager;
|
||||
use crate::schema::Schema;
|
||||
use crate::sqlite3_ondisk::{DatabaseHeader, MIN_PAGE_CACHE_SIZE};
|
||||
use crate::vdbe::{Insn, Program, ProgramBuilder};
|
||||
use crate::vdbe::{AggFunc, Insn, Program, ProgramBuilder};
|
||||
use anyhow::Result;
|
||||
use sqlite3_parser::ast::{
|
||||
Expr, Literal, OneSelect, PragmaBody, QualifiedName, Select, Stmt, UnaryOperator,
|
||||
};
|
||||
|
||||
enum AggregationFunc {
|
||||
Avg,
|
||||
Count,
|
||||
GroupConcat,
|
||||
Max,
|
||||
Min,
|
||||
StringAgg,
|
||||
Sum,
|
||||
Total,
|
||||
}
|
||||
|
||||
struct ColumnInfo {
|
||||
func: Option<AggregationFunc>,
|
||||
args: Option<Vec<Expr>>,
|
||||
columns_to_allocate: usize, /* number of result columns this col will result on */
|
||||
}
|
||||
|
||||
impl ColumnInfo {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
func: None,
|
||||
args: None,
|
||||
columns_to_allocate: 1,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_aggregation_function(&self) -> bool {
|
||||
return self.func.is_some();
|
||||
}
|
||||
}
|
||||
|
||||
/// Translate SQL statement into bytecode program.
|
||||
pub fn translate(
|
||||
schema: &Schema,
|
||||
@@ -31,7 +62,14 @@ fn translate_select(schema: &Schema, select: Select) -> Result<Program> {
|
||||
let start_offset = program.offset();
|
||||
let limit_reg = if let Some(limit) = select.limit {
|
||||
assert!(limit.offset.is_none());
|
||||
Some(translate_expr(&mut program, None, None, &limit.expr))
|
||||
let target_register = program.alloc_register();
|
||||
Some(translate_expr(
|
||||
&mut program,
|
||||
None,
|
||||
None,
|
||||
&limit.expr,
|
||||
target_register,
|
||||
))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
@@ -62,18 +100,60 @@ fn translate_select(schema: &Schema, select: Select) -> Result<Program> {
|
||||
program.emit_insn(Insn::OpenReadAwait);
|
||||
program.emit_insn(Insn::RewindAsync { cursor_id });
|
||||
let rewind_await_offset = program.emit_placeholder();
|
||||
let (register_start, register_end) =
|
||||
translate_columns(&mut program, Some(cursor_id), Some(table), columns);
|
||||
program.emit_insn(Insn::ResultRow {
|
||||
register_start,
|
||||
register_end,
|
||||
});
|
||||
let info_per_columns = analyze_columns(&columns, Some(table));
|
||||
let exist_aggregation = info_per_columns.iter().any(|info| info.func.is_some());
|
||||
let (register_start, register_end) = translate_columns(
|
||||
&mut program,
|
||||
Some(cursor_id),
|
||||
Some(table),
|
||||
&columns,
|
||||
&info_per_columns,
|
||||
exist_aggregation,
|
||||
);
|
||||
if exist_aggregation {
|
||||
// Only one ResultRow will occurr with aggregations.
|
||||
program.emit_insn(Insn::NextAsync { cursor_id });
|
||||
program.emit_insn(Insn::NextAwait {
|
||||
cursor_id,
|
||||
pc_if_next: rewind_await_offset,
|
||||
});
|
||||
let mut target = register_start;
|
||||
for info in &info_per_columns {
|
||||
if info.is_aggregation_function() {
|
||||
let func = match info.func.as_ref().unwrap() {
|
||||
AggregationFunc::Avg => AggFunc::Avg,
|
||||
AggregationFunc::Count => todo!(),
|
||||
AggregationFunc::GroupConcat => todo!(),
|
||||
AggregationFunc::Max => todo!(),
|
||||
AggregationFunc::Min => todo!(),
|
||||
AggregationFunc::StringAgg => todo!(),
|
||||
AggregationFunc::Sum => todo!(),
|
||||
AggregationFunc::Total => todo!(),
|
||||
};
|
||||
program.emit_insn(Insn::AggFinal {
|
||||
register: target,
|
||||
func,
|
||||
});
|
||||
}
|
||||
target += info.columns_to_allocate;
|
||||
}
|
||||
// only one result row
|
||||
program.emit_insn(Insn::ResultRow {
|
||||
register_start,
|
||||
register_end,
|
||||
});
|
||||
} else {
|
||||
program.emit_insn(Insn::ResultRow {
|
||||
register_start,
|
||||
register_end,
|
||||
});
|
||||
program.emit_insn(Insn::NextAsync { cursor_id });
|
||||
program.emit_insn(Insn::NextAwait {
|
||||
cursor_id,
|
||||
pc_if_next: rewind_await_offset,
|
||||
});
|
||||
}
|
||||
let limit_decr_insn = limit_reg.map(|_| program.emit_placeholder());
|
||||
program.emit_insn(Insn::NextAsync { cursor_id });
|
||||
program.emit_insn(Insn::NextAwait {
|
||||
cursor_id,
|
||||
pc_if_next: rewind_await_offset,
|
||||
});
|
||||
program.fixup_insn(
|
||||
rewind_await_offset,
|
||||
Insn::RewindAwait {
|
||||
@@ -88,8 +168,17 @@ fn translate_select(schema: &Schema, select: Select) -> Result<Program> {
|
||||
from: None,
|
||||
..
|
||||
} => {
|
||||
let (register_start, register_end) =
|
||||
translate_columns(&mut program, None, None, columns);
|
||||
let info_per_columns = analyze_columns(&columns, None);
|
||||
let exist_aggregation = info_per_columns.iter().any(|info| info.func.is_some());
|
||||
assert!(!exist_aggregation);
|
||||
let (register_start, register_end) = translate_columns(
|
||||
&mut program,
|
||||
None,
|
||||
None,
|
||||
&columns,
|
||||
&info_per_columns,
|
||||
exist_aggregation,
|
||||
);
|
||||
program.emit_insn(Insn::ResultRow {
|
||||
register_start,
|
||||
register_end,
|
||||
@@ -126,13 +215,33 @@ fn translate_columns(
|
||||
program: &mut ProgramBuilder,
|
||||
cursor_id: Option<usize>,
|
||||
table: Option<&crate::schema::Table>,
|
||||
columns: Vec<sqlite3_parser::ast::ResultColumn>,
|
||||
columns: &Vec<sqlite3_parser::ast::ResultColumn>,
|
||||
info_per_columns: &Vec<ColumnInfo>,
|
||||
exist_aggregation: bool,
|
||||
) -> (usize, usize) {
|
||||
let register_start = program.next_free_register();
|
||||
for col in columns {
|
||||
translate_column(program, cursor_id, table, col);
|
||||
}
|
||||
|
||||
// allocate one register as output for each col
|
||||
let registers: usize = info_per_columns
|
||||
.iter()
|
||||
.map(|col| col.columns_to_allocate)
|
||||
.sum();
|
||||
program.alloc_registers(registers);
|
||||
let register_end = program.next_free_register();
|
||||
|
||||
let mut target = register_start;
|
||||
for (col, info) in columns.iter().zip(info_per_columns) {
|
||||
translate_column(
|
||||
program,
|
||||
cursor_id,
|
||||
table,
|
||||
col,
|
||||
info,
|
||||
exist_aggregation,
|
||||
target,
|
||||
);
|
||||
target += info.columns_to_allocate;
|
||||
}
|
||||
(register_start, register_end)
|
||||
}
|
||||
|
||||
@@ -140,24 +249,36 @@ fn translate_column(
|
||||
program: &mut ProgramBuilder,
|
||||
cursor_id: Option<usize>,
|
||||
table: Option<&crate::schema::Table>,
|
||||
col: sqlite3_parser::ast::ResultColumn,
|
||||
col: &sqlite3_parser::ast::ResultColumn,
|
||||
info: &ColumnInfo,
|
||||
exist_aggregation: bool, // notify this column there is aggregation going on in other columns (or this one)
|
||||
target_register: usize, // where to store the result, in case of star it will be the start of registers added
|
||||
) {
|
||||
if exist_aggregation && !info.is_aggregation_function() {
|
||||
// FIXME: let's do nothing
|
||||
return;
|
||||
}
|
||||
|
||||
match col {
|
||||
sqlite3_parser::ast::ResultColumn::Expr(expr, _) => {
|
||||
let _ = translate_expr(program, cursor_id, table, &expr);
|
||||
if info.is_aggregation_function() {
|
||||
let _ =
|
||||
translate_aggregation(program, cursor_id, table, &expr, info, target_register);
|
||||
} else {
|
||||
let _ = translate_expr(program, cursor_id, table, &expr, target_register);
|
||||
}
|
||||
}
|
||||
sqlite3_parser::ast::ResultColumn::Star => {
|
||||
for (i, col) in table.unwrap().columns.iter().enumerate() {
|
||||
let dest = program.alloc_register();
|
||||
if col.is_rowid_alias() {
|
||||
program.emit_insn(Insn::RowId {
|
||||
cursor_id: cursor_id.unwrap(),
|
||||
dest,
|
||||
dest: target_register + i,
|
||||
});
|
||||
} else {
|
||||
program.emit_insn(Insn::Column {
|
||||
column: i,
|
||||
dest,
|
||||
dest: target_register + i,
|
||||
cursor_id: cursor_id.unwrap(),
|
||||
});
|
||||
}
|
||||
@@ -167,11 +288,72 @@ fn translate_column(
|
||||
}
|
||||
}
|
||||
|
||||
fn analyze_columns(
|
||||
columns: &Vec<sqlite3_parser::ast::ResultColumn>,
|
||||
table: Option<&crate::schema::Table>,
|
||||
) -> Vec<ColumnInfo> {
|
||||
let mut column_information_list = Vec::new();
|
||||
column_information_list.reserve(columns.len());
|
||||
|
||||
for column in columns {
|
||||
let mut info = ColumnInfo::new();
|
||||
info.columns_to_allocate = 1;
|
||||
if let sqlite3_parser::ast::ResultColumn::Star = column {
|
||||
info.columns_to_allocate = table.unwrap().columns.len();
|
||||
} else {
|
||||
analyze_column(column, &mut info);
|
||||
}
|
||||
column_information_list.push(info);
|
||||
}
|
||||
column_information_list
|
||||
}
|
||||
|
||||
/*
|
||||
Walk column expression trying to find aggregation functions. If it finds one it will save information
|
||||
about it.
|
||||
*/
|
||||
fn analyze_column(column: &sqlite3_parser::ast::ResultColumn, column_info_out: &mut ColumnInfo) {
|
||||
match column {
|
||||
sqlite3_parser::ast::ResultColumn::Expr(expr, _) => match expr {
|
||||
Expr::FunctionCall {
|
||||
name,
|
||||
distinctness: _,
|
||||
args,
|
||||
filter_over: _,
|
||||
} => {
|
||||
let func_type = match name.0.as_str() {
|
||||
"avg" => Some(AggregationFunc::Avg),
|
||||
"count" => Some(AggregationFunc::Count),
|
||||
"group_concat" => Some(AggregationFunc::GroupConcat),
|
||||
"max" => Some(AggregationFunc::Max),
|
||||
"min" => Some(AggregationFunc::Min),
|
||||
"string_agg" => Some(AggregationFunc::StringAgg),
|
||||
"sum" => Some(AggregationFunc::Sum),
|
||||
"total" => Some(AggregationFunc::Total),
|
||||
_ => None,
|
||||
};
|
||||
if func_type.is_none() {
|
||||
analyze_column(column, column_info_out);
|
||||
} else {
|
||||
column_info_out.func = func_type;
|
||||
// TODO(pere): use lifetimes for args? Arenas would be lovely here :(
|
||||
column_info_out.args = args.clone();
|
||||
}
|
||||
}
|
||||
Expr::FunctionCallStar { .. } => todo!(),
|
||||
_ => {}
|
||||
},
|
||||
sqlite3_parser::ast::ResultColumn::Star => {}
|
||||
sqlite3_parser::ast::ResultColumn::TableStar(_) => {}
|
||||
}
|
||||
}
|
||||
|
||||
fn translate_expr(
|
||||
program: &mut ProgramBuilder,
|
||||
cursor_id: Option<usize>,
|
||||
table: Option<&crate::schema::Table>,
|
||||
expr: &Expr,
|
||||
target_register: usize,
|
||||
) -> usize {
|
||||
match expr {
|
||||
Expr::Between { .. } => todo!(),
|
||||
@@ -185,20 +367,19 @@ fn translate_expr(
|
||||
Expr::FunctionCallStar { .. } => todo!(),
|
||||
Expr::Id(ident) => {
|
||||
let (idx, col) = table.unwrap().get_column(&ident.0).unwrap();
|
||||
let dest = program.alloc_register();
|
||||
if col.primary_key {
|
||||
program.emit_insn(Insn::RowId {
|
||||
cursor_id: cursor_id.unwrap(),
|
||||
dest,
|
||||
dest: target_register,
|
||||
});
|
||||
} else {
|
||||
program.emit_insn(Insn::Column {
|
||||
column: idx,
|
||||
dest,
|
||||
dest: target_register,
|
||||
cursor_id: cursor_id.unwrap(),
|
||||
});
|
||||
}
|
||||
dest
|
||||
target_register
|
||||
}
|
||||
Expr::InList { .. } => todo!(),
|
||||
Expr::InSelect { .. } => todo!(),
|
||||
@@ -207,20 +388,18 @@ fn translate_expr(
|
||||
Expr::Like { .. } => todo!(),
|
||||
Expr::Literal(lit) => match lit {
|
||||
Literal::Numeric(val) => {
|
||||
let dest = program.alloc_register();
|
||||
program.emit_insn(Insn::Integer {
|
||||
value: val.parse().unwrap(),
|
||||
dest,
|
||||
dest: target_register,
|
||||
});
|
||||
dest
|
||||
target_register
|
||||
}
|
||||
Literal::String(s) => {
|
||||
let dest = program.alloc_register();
|
||||
program.emit_insn(Insn::String8 {
|
||||
value: s[1..s.len() - 1].to_string(),
|
||||
dest,
|
||||
dest: target_register,
|
||||
});
|
||||
dest
|
||||
target_register
|
||||
}
|
||||
Literal::Blob(_) => todo!(),
|
||||
Literal::Keyword(_) => todo!(),
|
||||
@@ -240,6 +419,44 @@ fn translate_expr(
|
||||
}
|
||||
}
|
||||
|
||||
fn translate_aggregation(
|
||||
program: &mut ProgramBuilder,
|
||||
cursor_id: Option<usize>,
|
||||
table: Option<&crate::schema::Table>,
|
||||
expr: &Expr,
|
||||
info: &ColumnInfo,
|
||||
target_register: usize,
|
||||
) -> Result<usize> {
|
||||
let _ = expr;
|
||||
assert!(info.func.is_some());
|
||||
let func = info.func.as_ref().unwrap();
|
||||
let args = info.args.as_ref().unwrap();
|
||||
let dest = match func {
|
||||
AggregationFunc::Avg => {
|
||||
if args.len() != 1 {
|
||||
anyhow::bail!("Parse error: avg bad number of arguments");
|
||||
}
|
||||
let expr = &args[0];
|
||||
let expr_reg = program.alloc_register();
|
||||
let _ = translate_expr(program, cursor_id, table, &expr, expr_reg);
|
||||
program.emit_insn(Insn::AggStep {
|
||||
acc_reg: target_register,
|
||||
col: expr_reg,
|
||||
func: crate::vdbe::AggFunc::Avg,
|
||||
});
|
||||
target_register
|
||||
}
|
||||
AggregationFunc::Count => todo!(),
|
||||
AggregationFunc::GroupConcat => todo!(),
|
||||
AggregationFunc::Max => todo!(),
|
||||
AggregationFunc::Min => todo!(),
|
||||
AggregationFunc::StringAgg => todo!(),
|
||||
AggregationFunc::Sum => todo!(),
|
||||
AggregationFunc::Total => todo!(),
|
||||
};
|
||||
Ok(dest)
|
||||
}
|
||||
|
||||
fn translate_pragma(
|
||||
name: &QualifiedName,
|
||||
body: Option<PragmaBody>,
|
||||
|
||||
@@ -11,6 +11,11 @@ pub enum Value<'a> {
|
||||
Blob(&'a Vec<u8>),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub enum AggContext {
|
||||
Avg(f64, usize), // acc and count
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub enum OwnedValue {
|
||||
Null,
|
||||
@@ -18,6 +23,7 @@ pub enum OwnedValue {
|
||||
Float(f64),
|
||||
Text(Rc<String>),
|
||||
Blob(Rc<Vec<u8>>),
|
||||
Agg(Box<AggContext>),
|
||||
}
|
||||
|
||||
pub fn to_value(value: &OwnedValue) -> Value<'_> {
|
||||
@@ -27,6 +33,10 @@ pub fn to_value(value: &OwnedValue) -> Value<'_> {
|
||||
OwnedValue::Float(f) => Value::Float(*f),
|
||||
OwnedValue::Text(s) => Value::Text(s),
|
||||
OwnedValue::Blob(b) => Value::Blob(b),
|
||||
OwnedValue::Agg(a) => match a.as_ref() {
|
||||
AggContext::Avg(acc, _count) => Value::Float(*acc), // we assume aggfinal was called
|
||||
_ => todo!(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
87
core/vdbe.rs
87
core/vdbe.rs
@@ -1,9 +1,10 @@
|
||||
use crate::btree::BTreeCursor;
|
||||
use crate::pager::Pager;
|
||||
use crate::types::{Cursor, CursorResult, OwnedValue, Record};
|
||||
use crate::types::{AggContext, Cursor, CursorResult, OwnedValue, Record};
|
||||
|
||||
use anyhow::Result;
|
||||
use core::fmt;
|
||||
use std::borrow::BorrowMut;
|
||||
use std::cell::RefCell;
|
||||
use std::collections::BTreeMap;
|
||||
use std::rc::Rc;
|
||||
@@ -98,6 +99,30 @@ pub enum Insn {
|
||||
reg: usize,
|
||||
target_pc: BranchOffset,
|
||||
},
|
||||
|
||||
AggStep {
|
||||
acc_reg: usize,
|
||||
col: usize,
|
||||
func: AggFunc,
|
||||
},
|
||||
|
||||
AggFinal {
|
||||
register: usize,
|
||||
func: AggFunc,
|
||||
},
|
||||
}
|
||||
|
||||
pub enum AggFunc {
|
||||
Avg,
|
||||
}
|
||||
|
||||
impl AggFunc {
|
||||
fn to_string(&self) -> &str {
|
||||
match self {
|
||||
AggFunc::Avg => "avg",
|
||||
_ => "unknown",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ProgramBuilder {
|
||||
@@ -121,6 +146,12 @@ impl ProgramBuilder {
|
||||
reg
|
||||
}
|
||||
|
||||
pub fn alloc_registers(&mut self, amount: usize) -> usize {
|
||||
let reg = self.next_free_register;
|
||||
self.next_free_register += amount;
|
||||
reg
|
||||
}
|
||||
|
||||
pub fn next_free_register(&self) -> usize {
|
||||
self.next_free_register
|
||||
}
|
||||
@@ -334,6 +365,42 @@ impl Program {
|
||||
}
|
||||
_ => unreachable!("DecrJumpZero on non-integer register"),
|
||||
},
|
||||
Insn::AggStep { acc_reg, col, func } => {
|
||||
if let OwnedValue::Null = &state.registers[*acc_reg] {
|
||||
state.registers[*acc_reg] =
|
||||
OwnedValue::Agg(Box::new(AggContext::Avg(0.0, 0)));
|
||||
}
|
||||
match func {
|
||||
AggFunc::Avg => {
|
||||
let col = state.registers[*col].clone();
|
||||
let OwnedValue::Agg(agg) = state.registers[*acc_reg].borrow_mut()
|
||||
else {
|
||||
unreachable!();
|
||||
};
|
||||
let AggContext::Avg(acc, count) = agg.borrow_mut();
|
||||
match col {
|
||||
OwnedValue::Integer(i) => *acc += i as f64,
|
||||
OwnedValue::Float(f) => *acc += f,
|
||||
_ => unreachable!(),
|
||||
}
|
||||
*count += 1;
|
||||
}
|
||||
};
|
||||
state.pc += 1;
|
||||
}
|
||||
Insn::AggFinal { register, func } => {
|
||||
match func {
|
||||
AggFunc::Avg => {
|
||||
let OwnedValue::Agg(agg) = state.registers[*register].borrow_mut()
|
||||
else {
|
||||
unreachable!();
|
||||
};
|
||||
let AggContext::Avg(acc, count) = agg.borrow_mut();
|
||||
*acc /= *count as f64
|
||||
}
|
||||
};
|
||||
state.pc += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -547,6 +614,24 @@ fn insn_to_str(addr: usize, insn: &Insn) -> String {
|
||||
IntValue::Usize(0),
|
||||
"".to_string(),
|
||||
),
|
||||
Insn::AggStep { func, acc_reg, col } => (
|
||||
"AggStep",
|
||||
IntValue::Usize(0),
|
||||
IntValue::Usize(*col),
|
||||
IntValue::Usize(*acc_reg),
|
||||
func.to_string(),
|
||||
IntValue::Usize(0),
|
||||
format!("accum=r[{}] step({})", *acc_reg, *col),
|
||||
),
|
||||
Insn::AggFinal { register, func } => (
|
||||
"AggFinal",
|
||||
IntValue::Usize(0),
|
||||
IntValue::Usize(*register),
|
||||
IntValue::Usize(0),
|
||||
func.to_string(),
|
||||
IntValue::Usize(0),
|
||||
format!("accum=r[{}]", *register),
|
||||
),
|
||||
};
|
||||
format!(
|
||||
"{:<4} {:<13} {:<4} {:<4} {:<4} {:<13} {:<2} {}",
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
import sqlite3
|
||||
from faker import Faker
|
||||
import random
|
||||
|
||||
conn = sqlite3.connect('database.db')
|
||||
cursor = conn.cursor()
|
||||
@@ -17,7 +18,8 @@ cursor.execute('''
|
||||
address TEXT,
|
||||
city TEXT,
|
||||
state TEXT,
|
||||
zipcode TEXT
|
||||
zipcode TEXT,
|
||||
age INTEGER
|
||||
)
|
||||
''')
|
||||
|
||||
@@ -31,11 +33,14 @@ for _ in range(10000):
|
||||
city = fake.city()
|
||||
state = fake.state_abbr()
|
||||
zipcode = fake.zipcode()
|
||||
age = random.randint(0, 100) % 99
|
||||
|
||||
cursor.execute('''
|
||||
INSERT INTO users (first_name, last_name, email, phone_number, address, city, state, zipcode)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||||
''', (first_name, last_name, email, phone_number, address, city, state, zipcode))
|
||||
INSERT INTO users (first_name, last_name, email, phone_number, address, city, state, zipcode, age)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
''', (first_name, last_name, email, phone_number, address, city, state, zipcode, age))
|
||||
|
||||
|
||||
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
Reference in New Issue
Block a user