Validate aggregate-function arity in Func::resolve_function and report
unknown functions as a parse error.

* core/function.rs: resolve_function returns a LimboError instead of Err(());
  adds AggFunc::Count0 for zero-argument COUNT() and AggFunc::num_args.
* core/translate, core/vdbe: Count0 is handled like Count; COUNT() is planned
  with the literal 1 as its argument.
* testing/extensions.py: calling an extension function before the extension is
  loaded is now expected to print "error: no such function: ...".

diff --git a/core/function.rs b/core/function.rs
index 63a1170d9..79065ab39 100644
--- a/core/function.rs
+++ b/core/function.rs
@@ -3,6 +3,8 @@ use std::fmt;
 use std::fmt::{Debug, Display};
 use std::rc::Rc;
 
+use crate::LimboError;
+
 pub struct ExternalFunc {
     pub name: String,
     pub func: ExtFunc,
@@ -102,6 +104,7 @@ impl Display for JsonFunc {
 pub enum AggFunc {
     Avg,
     Count,
+    Count0,
     GroupConcat,
     Max,
     Min,
@@ -129,9 +132,25 @@ impl PartialEq for AggFunc {
 }
 
 impl AggFunc {
+    pub fn num_args(&self) -> usize {
+        match self {
+            Self::Avg => 1,
+            Self::Count0 => 0,
+            Self::Count => 1,
+            Self::GroupConcat => 1,
+            Self::Max => 1,
+            Self::Min => 1,
+            Self::StringAgg => 2,
+            Self::Sum => 1,
+            Self::Total => 1,
+            Self::External(func) => func.agg_args().unwrap_or(0),
+        }
+    }
+
     pub fn to_string(&self) -> &str {
         match self {
             Self::Avg => "avg",
+            Self::Count0 => "count",
             Self::Count => "count",
             Self::GroupConcat => "group_concat",
             Self::Max => "max",
@@ -390,19 +409,63 @@ pub struct FuncCtx {
 }
 
 impl Func {
-    pub fn resolve_function(name: &str, arg_count: usize) -> Result<Self, ()> {
+    pub fn resolve_function(name: &str, arg_count: usize) -> Result<Self, LimboError> {
         match name {
-            "avg" => Ok(Self::Agg(AggFunc::Avg)),
-            "count" => Ok(Self::Agg(AggFunc::Count)),
-            "group_concat" => Ok(Self::Agg(AggFunc::GroupConcat)),
-            "max" if arg_count == 0 || arg_count == 1 => Ok(Self::Agg(AggFunc::Max)),
+            "avg" => {
+                if arg_count != 1 {
+                    crate::bail_parse_error!("wrong number of arguments to function {}()", name)
+                }
+                Ok(Self::Agg(AggFunc::Avg))
+            }
+            "count" => {
+                // Handle both COUNT() and COUNT(expr) cases
+                if arg_count == 0 {
+                    Ok(Self::Agg(AggFunc::Count0)) // COUNT() case
+                } else if arg_count == 1 {
+                    Ok(Self::Agg(AggFunc::Count)) // COUNT(expr) case
+                } else {
+                    crate::bail_parse_error!("wrong number of arguments to function {}()", name)
+                }
+            }
+            "group_concat" => {
+                if arg_count != 1 && arg_count != 2 {
+                    crate::bail_parse_error!("wrong number of arguments to function {}()", name)
+                }
+                Ok(Self::Agg(AggFunc::GroupConcat))
+            }
             "max" if arg_count > 1 => Ok(Self::Scalar(ScalarFunc::Max)),
-            "min" if arg_count == 0 || arg_count == 1 => Ok(Self::Agg(AggFunc::Min)),
+            "max" => {
+                if arg_count < 1 {
+                    crate::bail_parse_error!("wrong number of arguments to function {}()", name)
+                }
+                Ok(Self::Agg(AggFunc::Max))
+            }
             "min" if arg_count > 1 => Ok(Self::Scalar(ScalarFunc::Min)),
+            "min" => {
+                if arg_count < 1 {
+                    crate::bail_parse_error!("wrong number of arguments to function {}()", name)
+                }
+                Ok(Self::Agg(AggFunc::Min))
+            }
             "nullif" if arg_count == 2 => Ok(Self::Scalar(ScalarFunc::Nullif)),
-            "string_agg" => Ok(Self::Agg(AggFunc::StringAgg)),
-            "sum" => Ok(Self::Agg(AggFunc::Sum)),
-            "total" => Ok(Self::Agg(AggFunc::Total)),
+            "string_agg" => {
+                if arg_count != 2 {
+                    crate::bail_parse_error!("wrong number of arguments to function {}()", name)
+                }
+                Ok(Self::Agg(AggFunc::StringAgg))
+            }
+            "sum" => {
+                if arg_count != 1 {
+                    crate::bail_parse_error!("wrong number of arguments to function {}()", name)
+                }
+                Ok(Self::Agg(AggFunc::Sum))
+            }
+            "total" => {
+                if arg_count != 1 {
+                    crate::bail_parse_error!("wrong number of arguments to function {}()", name)
+                }
+                Ok(Self::Agg(AggFunc::Total))
+            }
             "char" => Ok(Self::Scalar(ScalarFunc::Char)),
             "coalesce" => Ok(Self::Scalar(ScalarFunc::Coalesce)),
             "concat" => Ok(Self::Scalar(ScalarFunc::Concat)),
@@ -486,7 +549,7 @@ impl Func {
             "trunc" => Ok(Self::Math(MathFunc::Trunc)),
             #[cfg(not(target_family = "wasm"))]
             "load_extension" => Ok(Self::Scalar(ScalarFunc::LoadExtension)),
-            _ => Err(()),
+            _ => crate::bail_parse_error!("no such function: {}", name),
         }
     }
 }
diff --git a/core/translate/aggregation.rs b/core/translate/aggregation.rs
index 3fc3c4dac..d23caf3ec 100644
--- a/core/translate/aggregation.rs
+++ b/core/translate/aggregation.rs
@@ -74,7 +74,7 @@ pub fn translate_aggregation_step(
             });
             target_register
         }
-        AggFunc::Count => {
+        AggFunc::Count | AggFunc::Count0 => {
             let expr_reg = if agg.args.is_empty() {
                 program.alloc_register()
             } else {
@@ -87,7 +87,11 @@ pub fn translate_aggregation_step(
                 acc_reg: target_register,
                 col: expr_reg,
                 delimiter: 0,
-                func: AggFunc::Count,
+                func: if matches!(agg.func, AggFunc::Count0) {
+                    AggFunc::Count0
+                } else {
+                    AggFunc::Count
+                },
             });
             target_register
         }
diff --git a/core/translate/group_by.rs b/core/translate/group_by.rs
index 244c82867..c58d5f56a 100644
--- a/core/translate/group_by.rs
+++ b/core/translate/group_by.rs
@@ -463,14 +463,18 @@ pub fn translate_aggregation_step_groupby(
             });
             target_register
         }
-        AggFunc::Count => {
+        AggFunc::Count | AggFunc::Count0 => {
             let expr_reg = program.alloc_register();
             emit_column(program, expr_reg);
             program.emit_insn(Insn::AggStep {
                 acc_reg: target_register,
                 col: expr_reg,
                 delimiter: 0,
-                func: AggFunc::Count,
+                func: if matches!(agg.func, AggFunc::Count0) {
+                    AggFunc::Count0
+                } else {
+                    AggFunc::Count
+                },
             });
             target_register
         }
diff --git a/core/translate/select.rs b/core/translate/select.rs
index 157f55d9d..35c522494 100644
--- a/core/translate/select.rs
+++ b/core/translate/select.rs
@@ -9,10 +9,10 @@ use crate::translate::planner::{
     parse_where, resolve_aggregates, OperatorIdCounter,
 };
 use crate::util::normalize_ident;
-use crate::SymbolTable;
 use crate::{schema::Schema, vdbe::builder::ProgramBuilder, Result};
-use sqlite3_parser::ast;
+use crate::SymbolTable;
 use sqlite3_parser::ast::ResultColumn;
+use sqlite3_parser::ast::{self};
 
 pub fn translate_select(
     program: &mut ProgramBuilder,
@@ -116,9 +116,23 @@ pub fn prepare_select_plan(
                                 args_count,
                             ) {
                                 Ok(Func::Agg(f)) => {
+                                    let agg_args = match (args, &f) {
+                                        (None, crate::function::AggFunc::Count0) => {
+                                            // COUNT() case
+                                            vec![ast::Expr::Literal(ast::Literal::Numeric(
+                                                "1".to_string(),
+                                            ))]
+                                        }
+                                        (None, _) => crate::bail_parse_error!(
+                                            "Aggregate function {} requires arguments",
+                                            name.0
+                                        ),
+                                        (Some(args), _) => args.clone(),
+                                    };
+
                                     let agg = Aggregate {
                                         func: f,
-                                        args: args.as_ref().unwrap().clone(),
+                                        args: agg_args,
                                         original_expr: expr.clone(),
                                     };
                                     aggregate_expressions.push(agg.clone());
@@ -147,7 +161,7 @@ pub fn prepare_select_plan(
                                         contains_aggregates,
                                     });
                                 }
-                                Err(_) => {
+                                Err(e) => {
                                     if let Some(f) = syms.resolve_function(&name.0, args_count)
                                     {
                                         if let ExtFunc::Scalar(_) = f.as_ref().func {
@@ -183,6 +197,9 @@ pub fn prepare_select_plan(
                                                 contains_aggregates: true,
                                             });
                                         }
+                                        continue; // Continue with the normal flow instead of returning
+                                    } else {
+                                        return Err(e);
                                     }
                                 }
                             }
diff --git a/core/vdbe/mod.rs b/core/vdbe/mod.rs
index 63716d87e..e0798df5d 100644
--- a/core/vdbe/mod.rs
+++ b/core/vdbe/mod.rs
@@ -1204,7 +1204,7 @@ impl Program {
                                 // Total() never throws an integer overflow.
                                 OwnedValue::Agg(Box::new(AggContext::Sum(OwnedValue::Float(0.0))))
                             }
-                            AggFunc::Count => {
+                            AggFunc::Count | AggFunc::Count0 => {
                                 OwnedValue::Agg(Box::new(AggContext::Count(OwnedValue::Integer(0))))
                             }
                             AggFunc::Max => {
@@ -1289,7 +1289,12 @@ impl Program {
                             };
                             *acc += col;
                         }
-                        AggFunc::Count => {
+                        AggFunc::Count | AggFunc::Count0 => {
+                            if matches!(&state.registers[*acc_reg], OwnedValue::Null) {
+                                state.registers[*acc_reg] = OwnedValue::Agg(Box::new(
+                                    AggContext::Count(OwnedValue::Integer(0)),
+                                ));
+                            }
                             let OwnedValue::Agg(agg) = state.registers[*acc_reg].borrow_mut()
                             else {
                                 unreachable!();
@@ -1437,7 +1442,7 @@ impl Program {
                                 *acc /= count.clone();
                             }
                             AggFunc::Sum | AggFunc::Total => {}
-                            AggFunc::Count => {}
+                            AggFunc::Count | AggFunc::Count0 => {}
                             AggFunc::Max => {}
                             AggFunc::Min => {}
                             AggFunc::GroupConcat | AggFunc::StringAgg => {}
@@ -1451,7 +1456,7 @@ impl Program {
                             AggFunc::Total => {
                                 state.registers[*register] = OwnedValue::Float(0.0);
                             }
-                            AggFunc::Count => {
+                            AggFunc::Count | AggFunc::Count0 => {
                                 state.registers[*register] = OwnedValue::Integer(0);
                             }
                             _ => {}
diff --git a/testing/extensions.py b/testing/extensions.py
index 61255e804..9dcd27846 100755
--- a/testing/extensions.py
+++ b/testing/extensions.py
@@ -115,8 +115,12 @@ def validate_string_uuid(result):
     return len(result) == 36 and result.count("-") == 4
 
 
+def returns_error(result):
+    return "error: no such function: " in result
+
+
 def returns_null(result):
-    return result == "" or result == b"\n" or result == b""
+    return result == "" or result == "\n"
 
 
 def assert_now_unixtime(result):
@@ -135,10 +139,10 @@ def test_uuid(pipe):
     run_test(
         pipe,
         "SELECT uuid4();",
-        returns_null,
-        "uuid functions return null when ext not loaded",
+        returns_error,
+        "uuid functions return an error when ext not loaded",
     )
-    run_test(pipe, "SELECT uuid4_str();", returns_null)
+    run_test(pipe, "SELECT uuid4_str();", returns_error)
     run_test(
         pipe,
         f".load {extension_path}",
@@ -178,7 +182,7 @@ def test_regexp(pipe):
     extension_path = "./target/debug/liblimbo_regexp.so"
 
     # before extension loads, assert no function
-    run_test(pipe, "SELECT regexp('a.c', 'abc');", returns_null)
+    run_test(pipe, "SELECT regexp('a.c', 'abc');", returns_error)
     run_test(pipe, f".load {extension_path}", returns_null)
     print(f"Extension {extension_path} loaded successfully.")
     run_test(pipe, "SELECT regexp('a.c', 'abc');", validate_true)
@@ -225,7 +229,7 @@ def test_aggregates(pipe):
     run_test(
         pipe,
         "SELECT median(1);",
-        returns_null,
-        "median agg function returns null when ext not loaded",
+        returns_error,
+        "median agg function returns an error when ext not loaded",
     )
     run_test(