diff --git a/core/translate.rs b/core/translate.rs index 3c20fe73d..e6307280f 100644 --- a/core/translate.rs +++ b/core/translate.rs @@ -4,12 +4,43 @@ use std::rc::Rc; use crate::pager::Pager; use crate::schema::Schema; use crate::sqlite3_ondisk::{DatabaseHeader, MIN_PAGE_CACHE_SIZE}; -use crate::vdbe::{Insn, Program, ProgramBuilder}; +use crate::vdbe::{AggFunc, Insn, Program, ProgramBuilder}; use anyhow::Result; use sqlite3_parser::ast::{ Expr, Literal, OneSelect, PragmaBody, QualifiedName, Select, Stmt, UnaryOperator, }; +enum AggregationFunc { + Avg, + Count, + GroupConcat, + Max, + Min, + StringAgg, + Sum, + Total, +} + +struct ColumnAggregationInfo { + func: Option, + args: Option>, + columns_to_allocate: usize, /* number of result columns this col will result on */ +} + +impl ColumnAggregationInfo { + pub fn new() -> Self { + Self { + func: None, + args: None, + columns_to_allocate: 1, + } + } + + pub fn is_aggregation_function(&self) -> bool { + return self.func.is_some(); + } +} + /// Translate SQL statement into bytecode program. pub fn translate( schema: &Schema, @@ -31,7 +62,14 @@ fn translate_select(schema: &Schema, select: Select) -> Result { let start_offset = program.offset(); let limit_reg = if let Some(limit) = select.limit { assert!(limit.offset.is_none()); - Some(translate_expr(&mut program, None, None, &limit.expr)) + let target_register = program.alloc_register(); + Some(translate_expr( + &mut program, + None, + None, + &limit.expr, + target_register, + )) } else { None }; @@ -62,18 +100,59 @@ fn translate_select(schema: &Schema, select: Select) -> Result { program.emit_insn(Insn::OpenReadAwait); program.emit_insn(Insn::RewindAsync { cursor_id }); let rewind_await_offset = program.emit_placeholder(); - let (register_start, register_end) = - translate_columns(&mut program, Some(cursor_id), Some(table), columns); - program.emit_insn(Insn::ResultRow { - register_start, - register_end, - }); + let info_per_columns = analyze_columns(&columns, Some(table)); + let exist_aggregation = info_per_columns.iter().any(|info| info.func.is_some()); + let (register_start, register_end) = translate_columns( + &mut program, + Some(cursor_id), + Some(table), + &columns, + &info_per_columns, + exist_aggregation, + ); + if exist_aggregation { + program.emit_insn(Insn::NextAsync { cursor_id }); + program.emit_insn(Insn::NextAwait { + cursor_id, + pc_if_next: rewind_await_offset, + }); + let mut target = register_start; + for info in &info_per_columns { + if info.is_aggregation_function() { + let func = match info.func.as_ref().unwrap() { + AggregationFunc::Avg => AggFunc::Avg, + AggregationFunc::Count => todo!(), + AggregationFunc::GroupConcat => todo!(), + AggregationFunc::Max => todo!(), + AggregationFunc::Min => todo!(), + AggregationFunc::StringAgg => todo!(), + AggregationFunc::Sum => todo!(), + AggregationFunc::Total => todo!(), + }; + program.emit_insn(Insn::AggFinal { + register: target, + func, + }); + } + target += info.columns_to_allocate; + } + // only one result row + program.emit_insn(Insn::ResultRow { + register_start, + register_end, + }); + } else { + program.emit_insn(Insn::ResultRow { + register_start, + register_end, + }); + program.emit_insn(Insn::NextAsync { cursor_id }); + program.emit_insn(Insn::NextAwait { + cursor_id, + pc_if_next: rewind_await_offset, + }); + } let limit_decr_insn = limit_reg.map(|_| program.emit_placeholder()); - program.emit_insn(Insn::NextAsync { cursor_id }); - program.emit_insn(Insn::NextAwait { - cursor_id, - pc_if_next: rewind_await_offset, - }); program.fixup_insn( rewind_await_offset, Insn::RewindAwait { @@ -88,8 +167,16 @@ fn translate_select(schema: &Schema, select: Select) -> Result { from: None, .. } => { - let (register_start, register_end) = - translate_columns(&mut program, None, None, columns); + let info_per_columns = analyze_columns(&columns, None); + let exist_aggregation = info_per_columns.iter().any(|info| info.func.is_some()); + let (register_start, register_end) = translate_columns( + &mut program, + None, + None, + &columns, + &info_per_columns, + exist_aggregation, + ); program.emit_insn(Insn::ResultRow { register_start, register_end, @@ -126,13 +213,33 @@ fn translate_columns( program: &mut ProgramBuilder, cursor_id: Option, table: Option<&crate::schema::Table>, - columns: Vec, + columns: &Vec, + info_per_columns: &Vec, + exist_aggregation: bool, ) -> (usize, usize) { let register_start = program.next_free_register(); - for col in columns { - translate_column(program, cursor_id, table, col); - } + + // allocate one register as output for each col + let registers: usize = info_per_columns + .iter() + .map(|col| col.columns_to_allocate) + .sum(); + program.alloc_registers(registers); let register_end = program.next_free_register(); + + let mut target = register_start; + for (i, (col, info)) in columns.iter().zip(info_per_columns).enumerate() { + translate_column( + program, + cursor_id, + table, + col, + info, + exist_aggregation, + target, + ); + target += info.columns_to_allocate; + } (register_start, register_end) } @@ -140,24 +247,35 @@ fn translate_column( program: &mut ProgramBuilder, cursor_id: Option, table: Option<&crate::schema::Table>, - col: sqlite3_parser::ast::ResultColumn, + col: &sqlite3_parser::ast::ResultColumn, + info: &ColumnAggregationInfo, + exist_aggregation: bool, // notify this column there is aggregation going on in other columns (or this one) + target_register: usize, // where to store the result, in case of star it will be the start of registers added ) { + if exist_aggregation && !info.is_aggregation_function() { + // FIXME: let's do nothing + return; + } + match col { sqlite3_parser::ast::ResultColumn::Expr(expr, _) => { - let _ = translate_expr(program, cursor_id, table, &expr); + if info.is_aggregation_function() { + translate_aggregation(program, cursor_id, table, &expr, info, target_register); + } else { + let _ = translate_expr(program, cursor_id, table, &expr, target_register); + } } sqlite3_parser::ast::ResultColumn::Star => { for (i, col) in table.unwrap().columns.iter().enumerate() { - let dest = program.alloc_register(); if col.is_rowid_alias() { program.emit_insn(Insn::RowId { cursor_id: cursor_id.unwrap(), - dest, + dest: target_register + i, }); } else { program.emit_insn(Insn::Column { column: i, - dest, + dest: target_register + i, cursor_id: cursor_id.unwrap(), }); } @@ -167,11 +285,71 @@ fn translate_column( } } +fn analyze_columns( + columns: &Vec, + table: Option<&crate::schema::Table>, +) -> Vec { + let mut column_information_list = Vec::new(); + column_information_list.reserve(columns.len()); + + for column in columns { + let mut info = ColumnAggregationInfo::new(); + info.columns_to_allocate = 1; + if let sqlite3_parser::ast::ResultColumn::Star = column { + info.columns_to_allocate = table.unwrap().columns.len(); + } else { + analyze_column(column, &mut info); + } + column_information_list.push(info); + } + column_information_list +} + +fn analyze_column( + column: &sqlite3_parser::ast::ResultColumn, + column_info_out: &mut ColumnAggregationInfo, +) { + match column { + sqlite3_parser::ast::ResultColumn::Expr(expr, _) => match expr { + Expr::FunctionCall { + name, + distinctness, + args, + filter_over, + } => { + let func_type = match name.0.as_str() { + "avg" => Some(AggregationFunc::Avg), + "count" => Some(AggregationFunc::Count), + "group_concat" => Some(AggregationFunc::GroupConcat), + "max" => Some(AggregationFunc::Max), + "min" => Some(AggregationFunc::Min), + "string_agg" => Some(AggregationFunc::StringAgg), + "sum" => Some(AggregationFunc::Sum), + "total" => Some(AggregationFunc::Total), + _ => None, + }; + if func_type.is_none() { + analyze_column(column, column_info_out); + } else { + column_info_out.func = func_type; + // TODO(pere): use lifetimes for args? Arenas would be lovely here :( + column_info_out.args = args.clone(); + } + } + Expr::FunctionCallStar { .. } => todo!(), + _ => {} + }, + sqlite3_parser::ast::ResultColumn::Star => {} + sqlite3_parser::ast::ResultColumn::TableStar(_) => {} + } +} + fn translate_expr( program: &mut ProgramBuilder, cursor_id: Option, table: Option<&crate::schema::Table>, expr: &Expr, + target_register: usize, ) -> usize { match expr { Expr::Between { .. } => todo!(), @@ -185,20 +363,19 @@ fn translate_expr( Expr::FunctionCallStar { .. } => todo!(), Expr::Id(ident) => { let (idx, col) = table.unwrap().get_column(&ident.0).unwrap(); - let dest = program.alloc_register(); if col.primary_key { program.emit_insn(Insn::RowId { cursor_id: cursor_id.unwrap(), - dest, + dest: target_register, }); } else { program.emit_insn(Insn::Column { column: idx, - dest, + dest: target_register, cursor_id: cursor_id.unwrap(), }); } - dest + target_register } Expr::InList { .. } => todo!(), Expr::InSelect { .. } => todo!(), @@ -207,20 +384,18 @@ fn translate_expr( Expr::Like { .. } => todo!(), Expr::Literal(lit) => match lit { Literal::Numeric(val) => { - let dest = program.alloc_register(); program.emit_insn(Insn::Integer { value: val.parse().unwrap(), - dest, + dest: target_register, }); - dest + target_register } Literal::String(s) => { - let dest = program.alloc_register(); program.emit_insn(Insn::String8 { value: s[1..s.len() - 1].to_string(), - dest, + dest: target_register, }); - dest + target_register } Literal::Blob(_) => todo!(), Literal::Keyword(_) => todo!(), @@ -240,6 +415,43 @@ fn translate_expr( } } +fn translate_aggregation( + program: &mut ProgramBuilder, + cursor_id: Option, + table: Option<&crate::schema::Table>, + expr: &Expr, + info: &ColumnAggregationInfo, + target_register: usize, +) -> Result { + assert!(info.func.is_some()); + let func = info.func.as_ref().unwrap(); + let args = info.args.as_ref().unwrap(); + let dest = match func { + AggregationFunc::Avg => { + if args.len() != 1 { + anyhow::bail!("Parse error: avg bad number of arguments"); + } + let expr = &args[0]; + let expr_reg = program.alloc_register(); + let _ = translate_expr(program, cursor_id, table, &expr, expr_reg); + program.emit_insn(Insn::AggStep { + acc_reg: target_register, + col: expr_reg, + func: crate::vdbe::AggFunc::Avg, + }); + target_register + } + AggregationFunc::Count => todo!(), + AggregationFunc::GroupConcat => todo!(), + AggregationFunc::Max => todo!(), + AggregationFunc::Min => todo!(), + AggregationFunc::StringAgg => todo!(), + AggregationFunc::Sum => todo!(), + AggregationFunc::Total => todo!(), + }; + Ok(dest) +} + fn translate_pragma( name: &QualifiedName, body: Option, diff --git a/core/types.rs b/core/types.rs index b4d55ca62..acb7374ea 100644 --- a/core/types.rs +++ b/core/types.rs @@ -1,4 +1,4 @@ -use std::{cell::Ref, rc::Rc}; +use std::{borrow::Borrow, cell::Ref, ops::Add, rc::Rc}; use anyhow::Result; @@ -11,6 +11,11 @@ pub enum Value<'a> { Blob(&'a Vec), } +#[derive(Debug, Clone, PartialEq)] +pub enum AggContext { + Avg(f64, usize), // acc and count +} + #[derive(Debug, Clone, PartialEq)] pub enum OwnedValue { Null, @@ -18,6 +23,40 @@ pub enum OwnedValue { Float(f64), Text(Rc), Blob(Rc>), + Agg(Box), +} + +impl std::ops::AddAssign for OwnedValue { + fn add_assign(&mut self, rhs: Self) { + let l = self.clone(); + *self = l.add(rhs); + } +} + +impl std::ops::Add for OwnedValue { + type Output = OwnedValue; + + fn add(self, rhs: Self) -> Self::Output { + assert!(matches!(&self, rhs)); + assert!(matches!(&self, OwnedValue::Integer(_)) || matches!(&self, OwnedValue::Float(_))); + match &self { + OwnedValue::Integer(l) => { + if let OwnedValue::Integer(r) = rhs { + OwnedValue::Integer(l + r) + } else { + panic!(); + } + } + OwnedValue::Float(l) => { + if let OwnedValue::Float(r) = rhs { + OwnedValue::Float(l + r) + } else { + panic!(); + } + } + _ => todo!(), + } + } } pub fn to_value(value: &OwnedValue) -> Value<'_> { @@ -27,6 +66,10 @@ pub fn to_value(value: &OwnedValue) -> Value<'_> { OwnedValue::Float(f) => Value::Float(*f), OwnedValue::Text(s) => Value::Text(s), OwnedValue::Blob(b) => Value::Blob(b), + OwnedValue::Agg(a) => match a.as_ref() { + AggContext::Avg(acc, count) => Value::Float(*acc), // we assume aggfinal was called + _ => todo!(), + }, } } diff --git a/core/vdbe.rs b/core/vdbe.rs index 26c69ca93..e47fe3ff6 100644 --- a/core/vdbe.rs +++ b/core/vdbe.rs @@ -1,9 +1,10 @@ use crate::btree::BTreeCursor; use crate::pager::Pager; -use crate::types::{Cursor, CursorResult, OwnedValue, Record}; +use crate::types::{AggContext, Cursor, CursorResult, OwnedValue, Record}; use anyhow::Result; use core::fmt; +use std::borrow::{Borrow, BorrowMut}; use std::cell::RefCell; use std::collections::BTreeMap; use std::rc::Rc; @@ -98,6 +99,26 @@ pub enum Insn { reg: usize, target_pc: BranchOffset, }, + + AggStep { + acc_reg: usize, + col: usize, + func: AggFunc, + }, + + AggFinal { + register: usize, + func: AggFunc, + }, + + Copy { + register_start: usize, + register_end: usize, + }, +} + +pub enum AggFunc { + Avg, } pub struct ProgramBuilder { @@ -121,6 +142,12 @@ impl ProgramBuilder { reg } + pub fn alloc_registers(&mut self, amount: usize) -> usize { + let reg = self.next_free_register; + self.next_free_register += amount; + reg + } + pub fn next_free_register(&self) -> usize { self.next_free_register } @@ -334,6 +361,46 @@ impl Program { } _ => unreachable!("DecrJumpZero on non-integer register"), }, + Insn::AggStep { acc_reg, col, func } => { + if let OwnedValue::Null = &state.registers[*acc_reg] { + state.registers[*acc_reg] = + OwnedValue::Agg(Box::new(AggContext::Avg(0.0, 0))); + } + match func { + AggFunc::Avg => { + let col = state.registers[*col].clone(); + let OwnedValue::Agg(agg) = state.registers[*acc_reg].borrow_mut() + else { + unreachable!(); + }; + let AggContext::Avg(acc, count) = agg.borrow_mut(); + match col { + OwnedValue::Integer(i) => *acc += i as f64, + OwnedValue::Float(f) => *acc += f, + _ => unreachable!(), + } + *count += 1; + } + }; + state.pc += 1; + } + Insn::AggFinal { register, func } => { + match func { + AggFunc::Avg => { + let OwnedValue::Agg(agg) = state.registers[*register].borrow_mut() + else { + unreachable!(); + }; + let AggContext::Avg(acc, count) = agg.borrow_mut(); + *acc /= *count as f64 + } + }; + state.pc += 1; + } + Insn::Copy { + register_start, + register_end, + } => todo!(), } } } @@ -547,6 +614,28 @@ fn insn_to_str(addr: usize, insn: &Insn) -> String { IntValue::Usize(0), "".to_string(), ), + Insn::AggStep { func, acc_reg, col } => ( + "AggStep", + IntValue::Usize(0), + IntValue::Usize(*col), + IntValue::Usize(*acc_reg), + "avg", + IntValue::Usize(0), + format!("accum=r[{}] step({})", *acc_reg, *col), + ), + Insn::AggFinal { register, func } => ( + "AggFinal", + IntValue::Usize(0), + IntValue::Usize(*register), + IntValue::Usize(0), + "avg", + IntValue::Usize(0), + format!("accum=r[{}]", *register), + ), + Insn::Copy { + register_start, + register_end, + } => todo!(), }; format!( "{:<4} {:<13} {:<4} {:<4} {:<4} {:<13} {:<2} {}",