Implements group concat aggregate function

This commit is contained in:
Ramkarthik Krishnamurthy
2024-07-13 00:55:40 +05:30
parent 2540f7d127
commit 9268560a51
4 changed files with 90 additions and 4 deletions

View File

@@ -8,7 +8,7 @@ use crate::sqlite3_ondisk::{DatabaseHeader, MIN_PAGE_CACHE_SIZE};
use crate::util::normalize_ident;
use crate::vdbe::{Insn, Program, ProgramBuilder};
use anyhow::Result;
use sqlite3_parser::ast::{self, Expr};
use sqlite3_parser::ast::{self, Expr, Literal};
struct Select {
columns: Vec<ast::ResultColumn>,
@@ -685,6 +685,7 @@ fn translate_aggregation(
program.emit_insn(Insn::AggStep {
acc_reg: target_register,
col: expr_reg,
delimiter: 0,
func: AggFunc::Avg,
});
target_register
@@ -701,11 +702,42 @@ fn translate_aggregation(
program.emit_insn(Insn::AggStep {
acc_reg: target_register,
col: expr_reg,
delimiter: 0,
func: AggFunc::Count,
});
target_register
}
AggFunc::GroupConcat => todo!(),
AggFunc::GroupConcat => {
if args.len() != 1 && args.len() != 2 {
anyhow::bail!("Parse error: group_concat bad number of arguments");
}
let expr = &args[0];
let expr_reg = program.alloc_register();
let delimiter_reg = program.alloc_register();
let _ = translate_expr(program, select, expr, expr_reg);
if args.len() == 2 {
let delimiter = match &args[1] {
ast::Expr::Id(ident) => &ident.0,
ast::Expr::Literal(Literal::String(s)) => &s,
_ => anyhow::bail!("Incorrect delimiter parameter"),
};
let delimiter = ast::Expr::Literal(Literal::String(delimiter.to_string()));
let _ = translate_expr(program, select, &delimiter, delimiter_reg);
} else {
let delimiter = ast::Expr::Literal(Literal::String(String::from("\",\"")));
let _ = translate_expr(program, select, &delimiter, delimiter_reg);
}
program.emit_insn(Insn::AggStep {
acc_reg: target_register,
col: expr_reg,
delimiter: delimiter_reg,
func: AggFunc::GroupConcat,
});
target_register
}
AggFunc::Max => {
if args.len() != 1 {
anyhow::bail!("Parse error: max bad number of arguments");
@@ -716,6 +748,7 @@ fn translate_aggregation(
program.emit_insn(Insn::AggStep {
acc_reg: target_register,
col: expr_reg,
delimiter: 0,
func: AggFunc::Max,
});
target_register
@@ -730,6 +763,7 @@ fn translate_aggregation(
program.emit_insn(Insn::AggStep {
acc_reg: target_register,
col: expr_reg,
delimiter: 0,
func: AggFunc::Min,
});
target_register
@@ -745,6 +779,7 @@ fn translate_aggregation(
program.emit_insn(Insn::AggStep {
acc_reg: target_register,
col: expr_reg,
delimiter: 0,
func: AggFunc::Sum,
});
target_register
@@ -759,6 +794,7 @@ fn translate_aggregation(
program.emit_insn(Insn::AggStep {
acc_reg: target_register,
col: expr_reg,
delimiter: 0,
func: AggFunc::Total,
});
target_register

View File

@@ -49,6 +49,7 @@ impl Display for OwnedValue {
AggContext::Count(count) => write!(f, "{}", count),
AggContext::Max(max) => write!(f, "{}", max),
AggContext::Min(min) => write!(f, "{}", min),
AggContext::GroupConcat(s) => write!(f, "{}", s),
},
OwnedValue::Record(r) => write!(f, "{:?}", r),
}
@@ -62,6 +63,7 @@ pub enum AggContext {
Count(OwnedValue),
Max(OwnedValue),
Min(OwnedValue),
GroupConcat(OwnedValue),
}
impl std::ops::Add<OwnedValue> for OwnedValue {
@@ -81,6 +83,9 @@ impl std::ops::Add<OwnedValue> for OwnedValue {
(OwnedValue::Float(float_left), OwnedValue::Float(float_right)) => {
OwnedValue::Float(float_left + float_right)
}
(OwnedValue::Text(string_left), OwnedValue::Text(string_right)) => {
OwnedValue::Text(Rc::new(string_left.to_string() + &string_right.to_string()))
}
(lhs, OwnedValue::Null) => lhs,
(OwnedValue::Null, rhs) => rhs,
_ => OwnedValue::Float(0.0),
@@ -171,6 +176,7 @@ pub fn to_value(value: &OwnedValue) -> Value<'_> {
AggContext::Count(count) => to_value(count),
AggContext::Max(max) => to_value(max),
AggContext::Min(min) => to_value(min),
AggContext::GroupConcat(s) => to_value(s),
},
OwnedValue::Record(_) => todo!(),
}

View File

@@ -168,6 +168,7 @@ pub enum Insn {
AggStep {
acc_reg: usize,
col: usize,
delimiter: usize,
func: AggFunc,
},
@@ -671,7 +672,12 @@ impl Program {
}
_ => unreachable!("DecrJumpZero on non-integer register"),
},
Insn::AggStep { acc_reg, col, func } => {
Insn::AggStep {
acc_reg,
col,
delimiter,
func,
} => {
if let OwnedValue::Null = &state.registers[*acc_reg] {
state.registers[*acc_reg] = match func {
AggFunc::Avg => OwnedValue::Agg(Box::new(AggContext::Avg(
@@ -718,6 +724,9 @@ impl Program {
}
}
}
AggFunc::GroupConcat => OwnedValue::Agg(Box::new(
AggContext::GroupConcat(OwnedValue::Text(Rc::new("".to_string()))),
)),
_ => {
todo!();
}
@@ -821,6 +830,27 @@ impl Program {
}
}
}
AggFunc::GroupConcat => {
let col = state.registers[*col].clone();
let delimiter = state.registers[*delimiter].clone();
let OwnedValue::Agg(agg) = state.registers[*acc_reg].borrow_mut()
else {
unreachable!();
};
let AggContext::GroupConcat(acc) = agg.borrow_mut() else {
unreachable!();
};
// let AggContext::GroupConcat(acc, _col, delimiter) =
// state.registers.borrow_mut()
// else {
// unreachable!();
// };
if acc.to_string().len() == 0 {
*acc = col;
} else {
*acc += delimiter + col;
}
}
_ => {
todo!();
}
@@ -841,6 +871,7 @@ impl Program {
AggFunc::Count => {}
AggFunc::Max => {}
AggFunc::Min => {}
AggFunc::GroupConcat => {}
_ => {
todo!();
}
@@ -1240,7 +1271,12 @@ fn insn_to_str(addr: BranchOffset, insn: &Insn, indent: String) -> String {
0,
"".to_string(),
),
Insn::AggStep { func, acc_reg, col } => (
Insn::AggStep {
func,
acc_reg,
delimiter: _,
col,
} => (
"AggStep",
0,
*col as i32,

View File

@@ -61,6 +61,14 @@ do_execsql_test select-min {
SELECT min(age) FROM users;
} {1}
do_execsql_test select-group-concat {
SELECT group_concat(name) FROM products;
} {hat,cap,shirt,sweater,sweatshirt,shorts,jeans,sneakers,boots,coat,accessories}
do_execsql_test select-group-concat-with-delimiter {
SELECT group_concat(name, ';') FROM products;
} {hat;cap;shirt;sweater;sweatshirt;shorts;jeans;sneakers;boots;coat;accessories}
do_execsql_test select-limit-0 {
SELECT id FROM users LIMIT 0;
} {}