diff --git a/core/translate.rs b/core/translate.rs index 3c20fe73d..0bb17e310 100644 --- a/core/translate.rs +++ b/core/translate.rs @@ -4,12 +4,43 @@ use std::rc::Rc; use crate::pager::Pager; use crate::schema::Schema; use crate::sqlite3_ondisk::{DatabaseHeader, MIN_PAGE_CACHE_SIZE}; -use crate::vdbe::{Insn, Program, ProgramBuilder}; +use crate::vdbe::{AggFunc, Insn, Program, ProgramBuilder}; use anyhow::Result; use sqlite3_parser::ast::{ Expr, Literal, OneSelect, PragmaBody, QualifiedName, Select, Stmt, UnaryOperator, }; +enum AggregationFunc { + Avg, + Count, + GroupConcat, + Max, + Min, + StringAgg, + Sum, + Total, +} + +struct ColumnInfo { + func: Option, + args: Option>, + columns_to_allocate: usize, /* number of result columns this col will result on */ +} + +impl ColumnInfo { + pub fn new() -> Self { + Self { + func: None, + args: None, + columns_to_allocate: 1, + } + } + + pub fn is_aggregation_function(&self) -> bool { + return self.func.is_some(); + } +} + /// Translate SQL statement into bytecode program. pub fn translate( schema: &Schema, @@ -31,7 +62,14 @@ fn translate_select(schema: &Schema, select: Select) -> Result { let start_offset = program.offset(); let limit_reg = if let Some(limit) = select.limit { assert!(limit.offset.is_none()); - Some(translate_expr(&mut program, None, None, &limit.expr)) + let target_register = program.alloc_register(); + Some(translate_expr( + &mut program, + None, + None, + &limit.expr, + target_register, + )) } else { None }; @@ -62,18 +100,60 @@ fn translate_select(schema: &Schema, select: Select) -> Result { program.emit_insn(Insn::OpenReadAwait); program.emit_insn(Insn::RewindAsync { cursor_id }); let rewind_await_offset = program.emit_placeholder(); - let (register_start, register_end) = - translate_columns(&mut program, Some(cursor_id), Some(table), columns); - program.emit_insn(Insn::ResultRow { - register_start, - register_end, - }); + let info_per_columns = analyze_columns(&columns, Some(table)); + let exist_aggregation = info_per_columns.iter().any(|info| info.func.is_some()); + let (register_start, register_end) = translate_columns( + &mut program, + Some(cursor_id), + Some(table), + &columns, + &info_per_columns, + exist_aggregation, + ); + if exist_aggregation { + // Only one ResultRow will occurr with aggregations. + program.emit_insn(Insn::NextAsync { cursor_id }); + program.emit_insn(Insn::NextAwait { + cursor_id, + pc_if_next: rewind_await_offset, + }); + let mut target = register_start; + for info in &info_per_columns { + if info.is_aggregation_function() { + let func = match info.func.as_ref().unwrap() { + AggregationFunc::Avg => AggFunc::Avg, + AggregationFunc::Count => todo!(), + AggregationFunc::GroupConcat => todo!(), + AggregationFunc::Max => todo!(), + AggregationFunc::Min => todo!(), + AggregationFunc::StringAgg => todo!(), + AggregationFunc::Sum => todo!(), + AggregationFunc::Total => todo!(), + }; + program.emit_insn(Insn::AggFinal { + register: target, + func, + }); + } + target += info.columns_to_allocate; + } + // only one result row + program.emit_insn(Insn::ResultRow { + register_start, + register_end, + }); + } else { + program.emit_insn(Insn::ResultRow { + register_start, + register_end, + }); + program.emit_insn(Insn::NextAsync { cursor_id }); + program.emit_insn(Insn::NextAwait { + cursor_id, + pc_if_next: rewind_await_offset, + }); + } let limit_decr_insn = limit_reg.map(|_| program.emit_placeholder()); - program.emit_insn(Insn::NextAsync { cursor_id }); - program.emit_insn(Insn::NextAwait { - cursor_id, - pc_if_next: rewind_await_offset, - }); program.fixup_insn( rewind_await_offset, Insn::RewindAwait { @@ -88,8 +168,17 @@ fn translate_select(schema: &Schema, select: Select) -> Result { from: None, .. } => { - let (register_start, register_end) = - translate_columns(&mut program, None, None, columns); + let info_per_columns = analyze_columns(&columns, None); + let exist_aggregation = info_per_columns.iter().any(|info| info.func.is_some()); + assert!(!exist_aggregation); + let (register_start, register_end) = translate_columns( + &mut program, + None, + None, + &columns, + &info_per_columns, + exist_aggregation, + ); program.emit_insn(Insn::ResultRow { register_start, register_end, @@ -126,13 +215,33 @@ fn translate_columns( program: &mut ProgramBuilder, cursor_id: Option, table: Option<&crate::schema::Table>, - columns: Vec, + columns: &Vec, + info_per_columns: &Vec, + exist_aggregation: bool, ) -> (usize, usize) { let register_start = program.next_free_register(); - for col in columns { - translate_column(program, cursor_id, table, col); - } + + // allocate one register as output for each col + let registers: usize = info_per_columns + .iter() + .map(|col| col.columns_to_allocate) + .sum(); + program.alloc_registers(registers); let register_end = program.next_free_register(); + + let mut target = register_start; + for (col, info) in columns.iter().zip(info_per_columns) { + translate_column( + program, + cursor_id, + table, + col, + info, + exist_aggregation, + target, + ); + target += info.columns_to_allocate; + } (register_start, register_end) } @@ -140,24 +249,36 @@ fn translate_column( program: &mut ProgramBuilder, cursor_id: Option, table: Option<&crate::schema::Table>, - col: sqlite3_parser::ast::ResultColumn, + col: &sqlite3_parser::ast::ResultColumn, + info: &ColumnInfo, + exist_aggregation: bool, // notify this column there is aggregation going on in other columns (or this one) + target_register: usize, // where to store the result, in case of star it will be the start of registers added ) { + if exist_aggregation && !info.is_aggregation_function() { + // FIXME: let's do nothing + return; + } + match col { sqlite3_parser::ast::ResultColumn::Expr(expr, _) => { - let _ = translate_expr(program, cursor_id, table, &expr); + if info.is_aggregation_function() { + let _ = + translate_aggregation(program, cursor_id, table, &expr, info, target_register); + } else { + let _ = translate_expr(program, cursor_id, table, &expr, target_register); + } } sqlite3_parser::ast::ResultColumn::Star => { for (i, col) in table.unwrap().columns.iter().enumerate() { - let dest = program.alloc_register(); if col.is_rowid_alias() { program.emit_insn(Insn::RowId { cursor_id: cursor_id.unwrap(), - dest, + dest: target_register + i, }); } else { program.emit_insn(Insn::Column { column: i, - dest, + dest: target_register + i, cursor_id: cursor_id.unwrap(), }); } @@ -167,11 +288,72 @@ fn translate_column( } } +fn analyze_columns( + columns: &Vec, + table: Option<&crate::schema::Table>, +) -> Vec { + let mut column_information_list = Vec::new(); + column_information_list.reserve(columns.len()); + + for column in columns { + let mut info = ColumnInfo::new(); + info.columns_to_allocate = 1; + if let sqlite3_parser::ast::ResultColumn::Star = column { + info.columns_to_allocate = table.unwrap().columns.len(); + } else { + analyze_column(column, &mut info); + } + column_information_list.push(info); + } + column_information_list +} + +/* + Walk column expression trying to find aggregation functions. If it finds one it will save information + about it. +*/ +fn analyze_column(column: &sqlite3_parser::ast::ResultColumn, column_info_out: &mut ColumnInfo) { + match column { + sqlite3_parser::ast::ResultColumn::Expr(expr, _) => match expr { + Expr::FunctionCall { + name, + distinctness: _, + args, + filter_over: _, + } => { + let func_type = match name.0.as_str() { + "avg" => Some(AggregationFunc::Avg), + "count" => Some(AggregationFunc::Count), + "group_concat" => Some(AggregationFunc::GroupConcat), + "max" => Some(AggregationFunc::Max), + "min" => Some(AggregationFunc::Min), + "string_agg" => Some(AggregationFunc::StringAgg), + "sum" => Some(AggregationFunc::Sum), + "total" => Some(AggregationFunc::Total), + _ => None, + }; + if func_type.is_none() { + analyze_column(column, column_info_out); + } else { + column_info_out.func = func_type; + // TODO(pere): use lifetimes for args? Arenas would be lovely here :( + column_info_out.args = args.clone(); + } + } + Expr::FunctionCallStar { .. } => todo!(), + _ => {} + }, + sqlite3_parser::ast::ResultColumn::Star => {} + sqlite3_parser::ast::ResultColumn::TableStar(_) => {} + } +} + fn translate_expr( program: &mut ProgramBuilder, cursor_id: Option, table: Option<&crate::schema::Table>, expr: &Expr, + target_register: usize, ) -> usize { match expr { Expr::Between { .. } => todo!(), @@ -185,20 +367,19 @@ fn translate_expr( Expr::FunctionCallStar { .. } => todo!(), Expr::Id(ident) => { let (idx, col) = table.unwrap().get_column(&ident.0).unwrap(); - let dest = program.alloc_register(); if col.primary_key { program.emit_insn(Insn::RowId { cursor_id: cursor_id.unwrap(), - dest, + dest: target_register, }); } else { program.emit_insn(Insn::Column { column: idx, - dest, + dest: target_register, cursor_id: cursor_id.unwrap(), }); } - dest + target_register } Expr::InList { .. } => todo!(), Expr::InSelect { .. } => todo!(), @@ -207,20 +388,18 @@ fn translate_expr( Expr::Like { .. } => todo!(), Expr::Literal(lit) => match lit { Literal::Numeric(val) => { - let dest = program.alloc_register(); program.emit_insn(Insn::Integer { value: val.parse().unwrap(), - dest, + dest: target_register, }); - dest + target_register } Literal::String(s) => { - let dest = program.alloc_register(); program.emit_insn(Insn::String8 { value: s[1..s.len() - 1].to_string(), - dest, + dest: target_register, }); - dest + target_register } Literal::Blob(_) => todo!(), Literal::Keyword(_) => todo!(), @@ -240,6 +419,44 @@ fn translate_expr( } } +fn translate_aggregation( + program: &mut ProgramBuilder, + cursor_id: Option, + table: Option<&crate::schema::Table>, + expr: &Expr, + info: &ColumnInfo, + target_register: usize, +) -> Result { + let _ = expr; + assert!(info.func.is_some()); + let func = info.func.as_ref().unwrap(); + let args = info.args.as_ref().unwrap(); + let dest = match func { + AggregationFunc::Avg => { + if args.len() != 1 { + anyhow::bail!("Parse error: avg bad number of arguments"); + } + let expr = &args[0]; + let expr_reg = program.alloc_register(); + let _ = translate_expr(program, cursor_id, table, &expr, expr_reg); + program.emit_insn(Insn::AggStep { + acc_reg: target_register, + col: expr_reg, + func: crate::vdbe::AggFunc::Avg, + }); + target_register + } + AggregationFunc::Count => todo!(), + AggregationFunc::GroupConcat => todo!(), + AggregationFunc::Max => todo!(), + AggregationFunc::Min => todo!(), + AggregationFunc::StringAgg => todo!(), + AggregationFunc::Sum => todo!(), + AggregationFunc::Total => todo!(), + }; + Ok(dest) +} + fn translate_pragma( name: &QualifiedName, body: Option, diff --git a/core/types.rs b/core/types.rs index b4d55ca62..86ef35beb 100644 --- a/core/types.rs +++ b/core/types.rs @@ -11,6 +11,11 @@ pub enum Value<'a> { Blob(&'a Vec), } +#[derive(Debug, Clone, PartialEq)] +pub enum AggContext { + Avg(f64, usize), // acc and count +} + #[derive(Debug, Clone, PartialEq)] pub enum OwnedValue { Null, @@ -18,6 +23,7 @@ pub enum OwnedValue { Float(f64), Text(Rc), Blob(Rc>), + Agg(Box), } pub fn to_value(value: &OwnedValue) -> Value<'_> { @@ -27,6 +33,10 @@ pub fn to_value(value: &OwnedValue) -> Value<'_> { OwnedValue::Float(f) => Value::Float(*f), OwnedValue::Text(s) => Value::Text(s), OwnedValue::Blob(b) => Value::Blob(b), + OwnedValue::Agg(a) => match a.as_ref() { + AggContext::Avg(acc, _count) => Value::Float(*acc), // we assume aggfinal was called + _ => todo!(), + }, } } diff --git a/core/vdbe.rs b/core/vdbe.rs index 26c69ca93..4c04c962f 100644 --- a/core/vdbe.rs +++ b/core/vdbe.rs @@ -1,9 +1,10 @@ use crate::btree::BTreeCursor; use crate::pager::Pager; -use crate::types::{Cursor, CursorResult, OwnedValue, Record}; +use crate::types::{AggContext, Cursor, CursorResult, OwnedValue, Record}; use anyhow::Result; use core::fmt; +use std::borrow::BorrowMut; use std::cell::RefCell; use std::collections::BTreeMap; use std::rc::Rc; @@ -98,6 +99,30 @@ pub enum Insn { reg: usize, target_pc: BranchOffset, }, + + AggStep { + acc_reg: usize, + col: usize, + func: AggFunc, + }, + + AggFinal { + register: usize, + func: AggFunc, + }, +} + +pub enum AggFunc { + Avg, +} + +impl AggFunc { + fn to_string(&self) -> &str { + match self { + AggFunc::Avg => "avg", + _ => "unknown", + } + } } pub struct ProgramBuilder { @@ -121,6 +146,12 @@ impl ProgramBuilder { reg } + pub fn alloc_registers(&mut self, amount: usize) -> usize { + let reg = self.next_free_register; + self.next_free_register += amount; + reg + } + pub fn next_free_register(&self) -> usize { self.next_free_register } @@ -334,6 +365,42 @@ impl Program { } _ => unreachable!("DecrJumpZero on non-integer register"), }, + Insn::AggStep { acc_reg, col, func } => { + if let OwnedValue::Null = &state.registers[*acc_reg] { + state.registers[*acc_reg] = + OwnedValue::Agg(Box::new(AggContext::Avg(0.0, 0))); + } + match func { + AggFunc::Avg => { + let col = state.registers[*col].clone(); + let OwnedValue::Agg(agg) = state.registers[*acc_reg].borrow_mut() + else { + unreachable!(); + }; + let AggContext::Avg(acc, count) = agg.borrow_mut(); + match col { + OwnedValue::Integer(i) => *acc += i as f64, + OwnedValue::Float(f) => *acc += f, + _ => unreachable!(), + } + *count += 1; + } + }; + state.pc += 1; + } + Insn::AggFinal { register, func } => { + match func { + AggFunc::Avg => { + let OwnedValue::Agg(agg) = state.registers[*register].borrow_mut() + else { + unreachable!(); + }; + let AggContext::Avg(acc, count) = agg.borrow_mut(); + *acc /= *count as f64 + } + }; + state.pc += 1; + } } } } @@ -547,6 +614,24 @@ fn insn_to_str(addr: usize, insn: &Insn) -> String { IntValue::Usize(0), "".to_string(), ), + Insn::AggStep { func, acc_reg, col } => ( + "AggStep", + IntValue::Usize(0), + IntValue::Usize(*col), + IntValue::Usize(*acc_reg), + func.to_string(), + IntValue::Usize(0), + format!("accum=r[{}] step({})", *acc_reg, *col), + ), + Insn::AggFinal { register, func } => ( + "AggFinal", + IntValue::Usize(0), + IntValue::Usize(*register), + IntValue::Usize(0), + func.to_string(), + IntValue::Usize(0), + format!("accum=r[{}]", *register), + ), }; format!( "{:<4} {:<13} {:<4} {:<4} {:<4} {:<13} {:<2} {}", diff --git a/testing/gen-database.py b/testing/gen-database.py index 53cc7c44c..2d73c678a 100755 --- a/testing/gen-database.py +++ b/testing/gen-database.py @@ -2,6 +2,7 @@ import sqlite3 from faker import Faker +import random conn = sqlite3.connect('database.db') cursor = conn.cursor() @@ -17,7 +18,8 @@ cursor.execute(''' address TEXT, city TEXT, state TEXT, - zipcode TEXT + zipcode TEXT, + age INTEGER ) ''') @@ -31,11 +33,14 @@ for _ in range(10000): city = fake.city() state = fake.state_abbr() zipcode = fake.zipcode() + age = random.randint(0, 100) % 99 cursor.execute(''' - INSERT INTO users (first_name, last_name, email, phone_number, address, city, state, zipcode) - VALUES (?, ?, ?, ?, ?, ?, ?, ?) - ''', (first_name, last_name, email, phone_number, address, city, state, zipcode)) + INSERT INTO users (first_name, last_name, email, phone_number, address, city, state, zipcode, age) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) + ''', (first_name, last_name, email, phone_number, address, city, state, zipcode, age)) + + conn.commit() conn.close()