From 9268560a511f88b0fd51a5ffce2668cab89415a0 Mon Sep 17 00:00:00 2001 From: Ramkarthik Krishnamurthy Date: Sat, 13 Jul 2024 00:55:40 +0530 Subject: [PATCH] Implements group concat aggregate function --- core/translate.rs | 40 ++++++++++++++++++++++++++++++++++++++-- core/types.rs | 6 ++++++ core/vdbe.rs | 40 ++++++++++++++++++++++++++++++++++++++-- testing/all.test | 8 ++++++++ 4 files changed, 90 insertions(+), 4 deletions(-) diff --git a/core/translate.rs b/core/translate.rs index 99a3c0a50..569977e96 100644 --- a/core/translate.rs +++ b/core/translate.rs @@ -8,7 +8,7 @@ use crate::sqlite3_ondisk::{DatabaseHeader, MIN_PAGE_CACHE_SIZE}; use crate::util::normalize_ident; use crate::vdbe::{Insn, Program, ProgramBuilder}; use anyhow::Result; -use sqlite3_parser::ast::{self, Expr}; +use sqlite3_parser::ast::{self, Expr, Literal}; struct Select { columns: Vec, @@ -685,6 +685,7 @@ fn translate_aggregation( program.emit_insn(Insn::AggStep { acc_reg: target_register, col: expr_reg, + delimiter: 0, func: AggFunc::Avg, }); target_register @@ -701,11 +702,42 @@ fn translate_aggregation( program.emit_insn(Insn::AggStep { acc_reg: target_register, col: expr_reg, + delimiter: 0, func: AggFunc::Count, }); target_register } - AggFunc::GroupConcat => todo!(), + AggFunc::GroupConcat => { + if args.len() != 1 && args.len() != 2 { + anyhow::bail!("Parse error: group_concat bad number of arguments"); + } + + let expr = &args[0]; + let expr_reg = program.alloc_register(); + let delimiter_reg = program.alloc_register(); + let _ = translate_expr(program, select, expr, expr_reg); + if args.len() == 2 { + let delimiter = match &args[1] { + ast::Expr::Id(ident) => &ident.0, + ast::Expr::Literal(Literal::String(s)) => &s, + _ => anyhow::bail!("Incorrect delimiter parameter"), + }; + let delimiter = ast::Expr::Literal(Literal::String(delimiter.to_string())); + let _ = translate_expr(program, select, &delimiter, delimiter_reg); + } else { + let delimiter = ast::Expr::Literal(Literal::String(String::from("\",\""))); + let _ = translate_expr(program, select, &delimiter, delimiter_reg); + } + + program.emit_insn(Insn::AggStep { + acc_reg: target_register, + col: expr_reg, + delimiter: delimiter_reg, + func: AggFunc::GroupConcat, + }); + + target_register + } AggFunc::Max => { if args.len() != 1 { anyhow::bail!("Parse error: max bad number of arguments"); @@ -716,6 +748,7 @@ fn translate_aggregation( program.emit_insn(Insn::AggStep { acc_reg: target_register, col: expr_reg, + delimiter: 0, func: AggFunc::Max, }); target_register @@ -730,6 +763,7 @@ fn translate_aggregation( program.emit_insn(Insn::AggStep { acc_reg: target_register, col: expr_reg, + delimiter: 0, func: AggFunc::Min, }); target_register @@ -745,6 +779,7 @@ fn translate_aggregation( program.emit_insn(Insn::AggStep { acc_reg: target_register, col: expr_reg, + delimiter: 0, func: AggFunc::Sum, }); target_register @@ -759,6 +794,7 @@ fn translate_aggregation( program.emit_insn(Insn::AggStep { acc_reg: target_register, col: expr_reg, + delimiter: 0, func: AggFunc::Total, }); target_register diff --git a/core/types.rs b/core/types.rs index 5f99f1482..f6913dcd8 100644 --- a/core/types.rs +++ b/core/types.rs @@ -49,6 +49,7 @@ impl Display for OwnedValue { AggContext::Count(count) => write!(f, "{}", count), AggContext::Max(max) => write!(f, "{}", max), AggContext::Min(min) => write!(f, "{}", min), + AggContext::GroupConcat(s) => write!(f, "{}", s), }, OwnedValue::Record(r) => write!(f, "{:?}", r), } @@ -62,6 +63,7 @@ pub enum AggContext { Count(OwnedValue), Max(OwnedValue), Min(OwnedValue), + GroupConcat(OwnedValue), } impl std::ops::Add for OwnedValue { @@ -81,6 +83,9 @@ impl std::ops::Add for OwnedValue { (OwnedValue::Float(float_left), OwnedValue::Float(float_right)) => { OwnedValue::Float(float_left + float_right) } + (OwnedValue::Text(string_left), OwnedValue::Text(string_right)) => { + OwnedValue::Text(Rc::new(string_left.to_string() + &string_right.to_string())) + } (lhs, OwnedValue::Null) => lhs, (OwnedValue::Null, rhs) => rhs, _ => OwnedValue::Float(0.0), @@ -171,6 +176,7 @@ pub fn to_value(value: &OwnedValue) -> Value<'_> { AggContext::Count(count) => to_value(count), AggContext::Max(max) => to_value(max), AggContext::Min(min) => to_value(min), + AggContext::GroupConcat(s) => to_value(s), }, OwnedValue::Record(_) => todo!(), } diff --git a/core/vdbe.rs b/core/vdbe.rs index e99451d9d..9c75ef961 100644 --- a/core/vdbe.rs +++ b/core/vdbe.rs @@ -168,6 +168,7 @@ pub enum Insn { AggStep { acc_reg: usize, col: usize, + delimiter: usize, func: AggFunc, }, @@ -671,7 +672,12 @@ impl Program { } _ => unreachable!("DecrJumpZero on non-integer register"), }, - Insn::AggStep { acc_reg, col, func } => { + Insn::AggStep { + acc_reg, + col, + delimiter, + func, + } => { if let OwnedValue::Null = &state.registers[*acc_reg] { state.registers[*acc_reg] = match func { AggFunc::Avg => OwnedValue::Agg(Box::new(AggContext::Avg( @@ -718,6 +724,9 @@ impl Program { } } } + AggFunc::GroupConcat => OwnedValue::Agg(Box::new( + AggContext::GroupConcat(OwnedValue::Text(Rc::new("".to_string()))), + )), _ => { todo!(); } @@ -821,6 +830,27 @@ impl Program { } } } + AggFunc::GroupConcat => { + let col = state.registers[*col].clone(); + let delimiter = state.registers[*delimiter].clone(); + let OwnedValue::Agg(agg) = state.registers[*acc_reg].borrow_mut() + else { + unreachable!(); + }; + let AggContext::GroupConcat(acc) = agg.borrow_mut() else { + unreachable!(); + }; + // let AggContext::GroupConcat(acc, _col, delimiter) = + // state.registers.borrow_mut() + // else { + // unreachable!(); + // }; + if acc.to_string().len() == 0 { + *acc = col; + } else { + *acc += delimiter + col; + } + } _ => { todo!(); } @@ -841,6 +871,7 @@ impl Program { AggFunc::Count => {} AggFunc::Max => {} AggFunc::Min => {} + AggFunc::GroupConcat => {} _ => { todo!(); } @@ -1240,7 +1271,12 @@ fn insn_to_str(addr: BranchOffset, insn: &Insn, indent: String) -> String { 0, "".to_string(), ), - Insn::AggStep { func, acc_reg, col } => ( + Insn::AggStep { + func, + acc_reg, + delimiter: _, + col, + } => ( "AggStep", 0, *col as i32, diff --git a/testing/all.test b/testing/all.test index 252ae03e9..15856a2ef 100755 --- a/testing/all.test +++ b/testing/all.test @@ -61,6 +61,14 @@ do_execsql_test select-min { SELECT min(age) FROM users; } {1} +do_execsql_test select-group-concat { + SELECT group_concat(name) FROM products; +} {hat,cap,shirt,sweater,sweatshirt,shorts,jeans,sneakers,boots,coat,accessories} + +do_execsql_test select-group-concat-with-delimiter { + SELECT group_concat(name, ';') FROM products; +} {hat;cap;shirt;sweater;sweatshirt;shorts;jeans;sneakers;boots;coat;accessories} + do_execsql_test select-limit-0 { SELECT id FROM users LIMIT 0; } {}