diff --git a/core/translate.rs b/core/translate.rs index aaaeccc20..e38ab4f01 100644 --- a/core/translate.rs +++ b/core/translate.rs @@ -8,7 +8,7 @@ use crate::sqlite3_ondisk::{DatabaseHeader, MIN_PAGE_CACHE_SIZE}; use crate::util::normalize_ident; use crate::vdbe::{BranchOffset, Insn, Program, ProgramBuilder}; use anyhow::Result; -use sqlite3_parser::ast::{self, Expr}; +use sqlite3_parser::ast::{self, Expr, Literal}; struct Select { columns: Vec, @@ -678,6 +678,7 @@ fn translate_aggregation( program.emit_insn(Insn::AggStep { acc_reg: target_register, col: expr_reg, + delimiter: 0, func: AggFunc::Avg, }); target_register @@ -694,11 +695,56 @@ fn translate_aggregation( program.emit_insn(Insn::AggStep { acc_reg: target_register, col: expr_reg, + delimiter: 0, func: AggFunc::Count, }); target_register } - AggFunc::GroupConcat => todo!(), + AggFunc::GroupConcat => { + if args.len() != 1 && args.len() != 2 { + anyhow::bail!("Parse error: group_concat bad number of arguments"); + } + + let expr_reg = program.alloc_register(); + let delimiter_reg = program.alloc_register(); + + let expr = &args[0]; + let delimiter_expr: ast::Expr; + + if args.len() == 2 { + match &args[1] { + ast::Expr::Id(ident) => { + if ident.0.starts_with("\"") { + delimiter_expr = ast::Expr::Literal(Literal::String(ident.0.to_string())); + } else { + delimiter_expr = args[1].clone(); + } + }, + ast::Expr::Literal(Literal::String(s)) => { + delimiter_expr = ast::Expr::Literal(Literal::String(s.to_string())); + }, + _ => anyhow::bail!("Incorrect delimiter parameter"), + }; + } else { + delimiter_expr = ast::Expr::Literal(Literal::String(String::from("\",\""))); + } + + if let Err(error) = translate_expr(program, select, expr, expr_reg) { + anyhow::bail!(error); + } + if let Err(error) = translate_expr(program, select, &delimiter_expr, delimiter_reg) { + anyhow::bail!(error); + } + + program.emit_insn(Insn::AggStep { + acc_reg: target_register, + col: expr_reg, + delimiter: delimiter_reg, + func: AggFunc::GroupConcat, + }); + + target_register + } AggFunc::Max => { if args.len() != 1 { anyhow::bail!("Parse error: max bad number of arguments"); @@ -709,6 +755,7 @@ fn translate_aggregation( program.emit_insn(Insn::AggStep { acc_reg: target_register, col: expr_reg, + delimiter: 0, func: AggFunc::Max, }); target_register @@ -723,11 +770,53 @@ fn translate_aggregation( program.emit_insn(Insn::AggStep { acc_reg: target_register, col: expr_reg, + delimiter: 0, func: AggFunc::Min, }); target_register } - AggFunc::StringAgg => todo!(), + AggFunc::StringAgg => { + if args.len() != 2 { + anyhow::bail!("Parse error: string_agg bad number of arguments"); + } + + + let expr_reg = program.alloc_register(); + let delimiter_reg = program.alloc_register(); + + let expr = &args[0]; + let delimiter_expr: ast::Expr; + + match &args[1] { + ast::Expr::Id(ident) => { + if ident.0.starts_with("\"") { + anyhow::bail!("Parse error: no such column: \",\" - should this be a string literal in single-quotes?"); + } else { + delimiter_expr = args[1].clone(); + } + }, + ast::Expr::Literal(Literal::String(s)) => { + delimiter_expr = ast::Expr::Literal(Literal::String(s.to_string())); + }, + _ => anyhow::bail!("Incorrect delimiter parameter"), + }; + + if let Err(error) = translate_expr(program, select, expr, expr_reg) { + anyhow::bail!(error); + } + if let Err(error) = translate_expr(program, select, &delimiter_expr, delimiter_reg) { + anyhow::bail!(error); + } + + program.emit_insn(Insn::AggStep { + acc_reg: target_register, + col: expr_reg, + delimiter: delimiter_reg, + func: AggFunc::StringAgg, + }); + + target_register + }, AggFunc::Sum => { if args.len() != 1 { anyhow::bail!("Parse error: sum bad number of arguments"); @@ -738,6 +827,7 @@ fn translate_aggregation( program.emit_insn(Insn::AggStep { acc_reg: target_register, col: expr_reg, + delimiter: 0, func: AggFunc::Sum, }); target_register @@ -752,6 +842,7 @@ fn translate_aggregation( program.emit_insn(Insn::AggStep { acc_reg: target_register, col: expr_reg, + delimiter: 0, func: AggFunc::Total, }); target_register diff --git a/core/types.rs b/core/types.rs index 5f99f1482..30fdf6d85 100644 --- a/core/types.rs +++ b/core/types.rs @@ -40,7 +40,7 @@ impl Display for OwnedValue { match self { OwnedValue::Null => write!(f, "NULL"), OwnedValue::Integer(i) => write!(f, "{}", i), - OwnedValue::Float(fl) => write!(f, "{}", fl), + OwnedValue::Float(fl) => write!(f, "{:?}", fl), OwnedValue::Text(s) => write!(f, "{}", s), OwnedValue::Blob(b) => write!(f, "{:?}", b), OwnedValue::Agg(a) => match a.as_ref() { @@ -49,6 +49,7 @@ impl Display for OwnedValue { AggContext::Count(count) => write!(f, "{}", count), AggContext::Max(max) => write!(f, "{}", max), AggContext::Min(min) => write!(f, "{}", min), + AggContext::GroupConcat(s) => write!(f, "{}", s), }, OwnedValue::Record(r) => write!(f, "{:?}", r), } @@ -62,6 +63,7 @@ pub enum AggContext { Count(OwnedValue), Max(OwnedValue), Min(OwnedValue), + GroupConcat(OwnedValue), } impl std::ops::Add for OwnedValue { @@ -81,6 +83,23 @@ impl std::ops::Add for OwnedValue { (OwnedValue::Float(float_left), OwnedValue::Float(float_right)) => { OwnedValue::Float(float_left + float_right) } + (OwnedValue::Text(string_left), OwnedValue::Text(string_right)) => { + OwnedValue::Text(Rc::new(string_left.to_string() + &string_right.to_string())) + } + (OwnedValue::Text(string_left), OwnedValue::Integer(int_right)) => { + OwnedValue::Text(Rc::new(string_left.to_string() + &int_right.to_string())) + } + (OwnedValue::Integer(int_left), OwnedValue::Text(string_right)) => { + OwnedValue::Text(Rc::new(int_left.to_string() + &string_right.to_string())) + } + (OwnedValue::Text(string_left), OwnedValue::Float(float_right)) => { + let string_right = OwnedValue::Float(float_right).to_string(); + OwnedValue::Text(Rc::new(string_left.to_string() + &string_right)) + } + (OwnedValue::Float(float_left), OwnedValue::Text(string_right)) => { + let string_left = OwnedValue::Float(float_left).to_string(); + OwnedValue::Text(Rc::new(string_left + &string_right.to_string())) + } (lhs, OwnedValue::Null) => lhs, (OwnedValue::Null, rhs) => rhs, _ => OwnedValue::Float(0.0), @@ -171,6 +190,7 @@ pub fn to_value(value: &OwnedValue) -> Value<'_> { AggContext::Count(count) => to_value(count), AggContext::Max(max) => to_value(max), AggContext::Min(min) => to_value(min), + AggContext::GroupConcat(s) => to_value(s), }, OwnedValue::Record(_) => todo!(), } diff --git a/core/vdbe.rs b/core/vdbe.rs index 33e09dd6e..5f9353fcb 100644 --- a/core/vdbe.rs +++ b/core/vdbe.rs @@ -169,6 +169,7 @@ pub enum Insn { AggStep { acc_reg: usize, col: usize, + delimiter: usize, func: AggFunc, }, @@ -800,7 +801,12 @@ impl Program { } _ => unreachable!("DecrJumpZero on non-integer register"), }, - Insn::AggStep { acc_reg, col, func } => { + Insn::AggStep { + acc_reg, + col, + delimiter, + func, + } => { if let OwnedValue::Null = &state.registers[*acc_reg] { state.registers[*acc_reg] = match func { AggFunc::Avg => OwnedValue::Agg(Box::new(AggContext::Avg( @@ -847,9 +853,10 @@ impl Program { } } } - _ => { - todo!(); - } + AggFunc::GroupConcat | + AggFunc::StringAgg => OwnedValue::Agg(Box::new( + AggContext::GroupConcat(OwnedValue::Text(Rc::new("".to_string()))), + )), }; } match func { @@ -950,8 +957,23 @@ impl Program { } } } - _ => { - todo!(); + AggFunc::GroupConcat | + AggFunc::StringAgg => { + let col = state.registers[*col].clone(); + let delimiter = state.registers[*delimiter].clone(); + let OwnedValue::Agg(agg) = state.registers[*acc_reg].borrow_mut() + else { + unreachable!(); + }; + let AggContext::GroupConcat(acc) = agg.borrow_mut() else { + unreachable!(); + }; + if acc.to_string().len() == 0 { + *acc = col; + } else { + *acc += delimiter; + *acc += col; + } } }; state.pc += 1; @@ -970,9 +992,7 @@ impl Program { AggFunc::Count => {} AggFunc::Max => {} AggFunc::Min => {} - _ => { - todo!(); - } + AggFunc::GroupConcat | AggFunc::StringAgg => {} }; } OwnedValue::Null => { @@ -1369,7 +1389,12 @@ fn insn_to_str(addr: BranchOffset, insn: &Insn, indent: String) -> String { 0, "".to_string(), ), - Insn::AggStep { func, acc_reg, col } => ( + Insn::AggStep { + func, + acc_reg, + delimiter: _, + col, + } => ( "AggStep", 0, *col as i32, diff --git a/testing/all.test b/testing/all.test index 252ae03e9..215e3dc05 100755 --- a/testing/all.test +++ b/testing/all.test @@ -61,6 +61,26 @@ do_execsql_test select-min { SELECT min(age) FROM users; } {1} +do_execsql_test select-group-concat { + SELECT group_concat(name) FROM products; +} {hat,cap,shirt,sweater,sweatshirt,shorts,jeans,sneakers,boots,coat,accessories} + +do_execsql_test select-group-concat-with-delimiter { + SELECT group_concat(name, ';') FROM products; +} {hat;cap;shirt;sweater;sweatshirt;shorts;jeans;sneakers;boots;coat;accessories} + +do_execsql_test select-group-concat-with-column-delimiter { + SELECT group_concat(name, id) FROM products; +} {hat2cap3shirt4sweater5sweatshirt6shorts7jeans8sneakers9boots10coat11accessories} + +do_execsql_test select-string-agg-with-delimiter { + SELECT string_agg(name, ',') FROM products; +} {hat,cap,shirt,sweater,sweatshirt,shorts,jeans,sneakers,boots,coat,accessories} + +do_execsql_test select-string-agg-with-column-delimiter { + SELECT string_agg(name, id) FROM products; +} {hat2cap3shirt4sweater5sweatshirt6shorts7jeans8sneakers9boots10coat11accessories} + do_execsql_test select-limit-0 { SELECT id FROM users LIMIT 0; } {}