From bff3af632691044b4a5da39b2514f44f47af7a9c Mon Sep 17 00:00:00 2001 From: Krishna Vishal Date: Sat, 18 Jan 2025 11:18:12 +0530 Subject: [PATCH 01/12] Converted the unconditional unwrap to a match which handles the case when the function is COUNT and args are None and replaces the args. Solves https://github.com/tursodatabase/limbo/issues/725 --- core/translate/select.rs | 24 +++++++++++++++++++++--- 1 file changed, 21 insertions(+), 3 deletions(-) diff --git a/core/translate/select.rs b/core/translate/select.rs index 768474a8e..56fc4a462 100644 --- a/core/translate/select.rs +++ b/core/translate/select.rs @@ -1,3 +1,5 @@ +use std::ops::Deref; + use super::emitter::emit_program; use super::expr::get_name; use super::plan::SelectQueryType; @@ -9,10 +11,10 @@ use crate::translate::planner::{ parse_where, resolve_aggregates, OperatorIdCounter, }; use crate::util::normalize_ident; -use crate::SymbolTable; use crate::{schema::Schema, vdbe::builder::ProgramBuilder, Result}; -use sqlite3_parser::ast; +use crate::{LimboError, SymbolTable}; use sqlite3_parser::ast::ResultColumn; +use sqlite3_parser::ast::{self, Expr}; pub fn translate_select( program: &mut ProgramBuilder, @@ -116,9 +118,25 @@ pub fn prepare_select_plan( args_count, ) { Ok(Func::Agg(f)) => { + let count_args = vec![ast::Expr::Literal( + ast::Literal::Numeric("1".to_string()), + )]; + let agg_args: Result, LimboError> = match args { + // if args is None and its COUNT + None if name.0.to_uppercase() == "COUNT" => { + Ok(count_args) + } + // if args is None and the function is not COUNT + None => crate::bail_parse_error!( + "Aggregate function {} requires arguments", + name.0 + ), + Some(args) => Ok(args.clone()), + }; + let agg = Aggregate { func: f, - args: args.as_ref().unwrap().clone(), + args: agg_args?.clone(), original_expr: expr.clone(), }; aggregate_expressions.push(agg.clone()); From 4c1f4d71d631b336d391e482a92f62412f4d1496 Mon Sep 17 00:00:00 2001 From: Krishna Vishal Date: Sat, 18 Jan 2025 11:31:34 +0530 Subject: [PATCH 02/12] Refactor code --- core/translate/select.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/core/translate/select.rs b/core/translate/select.rs index 56fc4a462..50f918856 100644 --- a/core/translate/select.rs +++ b/core/translate/select.rs @@ -118,12 +118,12 @@ pub fn prepare_select_plan( args_count, ) { Ok(Func::Agg(f)) => { - let count_args = vec![ast::Expr::Literal( - ast::Literal::Numeric("1".to_string()), - )]; let agg_args: Result, LimboError> = match args { // if args is None and its COUNT None if name.0.to_uppercase() == "COUNT" => { + let count_args = vec![ast::Expr::Literal( + ast::Literal::Numeric("1".to_string()), + )]; Ok(count_args) } // if args is None and the function is not COUNT From ed5c8ddcf086421a26adf1d94e4ff038cffd43ba Mon Sep 17 00:00:00 2001 From: Krishna Vishal Date: Sat, 18 Jan 2025 11:35:02 +0530 Subject: [PATCH 03/12] Remove unused import --- core/translate/select.rs | 2 -- 1 file changed, 2 deletions(-) diff --git a/core/translate/select.rs b/core/translate/select.rs index 50f918856..c37466a9b 100644 --- a/core/translate/select.rs +++ b/core/translate/select.rs @@ -1,5 +1,3 @@ -use std::ops::Deref; - use super::emitter::emit_program; use super::expr::get_name; use super::plan::SelectQueryType; From 776aecbce0c8a77d7351c0b17b6653ebf2976d70 Mon Sep 17 00:00:00 2001 From: Krishna Vishal Date: Sun, 19 Jan 2025 03:00:18 +0530 Subject: [PATCH 04/12] 1. All aggregate functions now validate their args and return Parse error for wrong number of args. 2. Unknown functions now return `no such function: unknown_function` --- core/function.rs | 83 ++++++++++++++++++++++++++++++----- core/translate/aggregation.rs | 8 +++- core/translate/group_by.rs | 8 +++- core/translate/select.rs | 40 ++++++++++------- core/vdbe/mod.rs | 14 ++++-- 5 files changed, 118 insertions(+), 35 deletions(-) diff --git a/core/function.rs b/core/function.rs index 85a142437..c0f6609ca 100644 --- a/core/function.rs +++ b/core/function.rs @@ -4,6 +4,8 @@ use std::rc::Rc; use limbo_ext::ScalarFunction; +use crate::LimboError; + pub struct ExternalFunc { pub name: String, pub func: ScalarFunction, @@ -67,6 +69,7 @@ impl Display for JsonFunc { pub enum AggFunc { Avg, Count, + Count0, GroupConcat, Max, Min, @@ -76,9 +79,24 @@ pub enum AggFunc { } impl AggFunc { + pub fn num_args(&self) -> usize { + match self { + Self::Avg => 1, + Self::Count0 => 0, + Self::Count => 1, + Self::GroupConcat => 1, + Self::Max => 1, + Self::Min => 1, + Self::StringAgg => 2, + Self::Sum => 1, + Self::Total => 1, + } + } + pub fn to_string(&self) -> &str { match self { Self::Avg => "avg", + Self::Count0 => "count", Self::Count => "count", Self::GroupConcat => "group_concat", Self::Max => "max", @@ -336,19 +354,64 @@ pub struct FuncCtx { } impl Func { - pub fn resolve_function(name: &str, arg_count: usize) -> Result { + pub fn resolve_function(name: &str, arg_count: usize) -> Result { match name { - "avg" => Ok(Self::Agg(AggFunc::Avg)), - "count" => Ok(Self::Agg(AggFunc::Count)), - "group_concat" => Ok(Self::Agg(AggFunc::GroupConcat)), - "max" if arg_count == 0 || arg_count == 1 => Ok(Self::Agg(AggFunc::Max)), + "avg" => { + if arg_count != 1 { + crate::bail_parse_error!("wrong number of arguments to function {}()", name) + } + Ok(Self::Agg(AggFunc::Avg)) + } + "count" => { + // Handle both COUNT() and COUNT(expr) cases + if arg_count == 0 { + Ok(Self::Agg(AggFunc::Count0)) // COUNT() case + } else if arg_count == 1 { + Ok(Self::Agg(AggFunc::Count)) // COUNT(expr) case + } else { + crate::bail_parse_error!("wrong number of arguments to function {}()", name) + } + } + "group_concat" => { + if arg_count != 1 && arg_count != 2 { + println!("{}", arg_count); + crate::bail_parse_error!("wrong number of arguments to function {}()", name) + } + Ok(Self::Agg(AggFunc::GroupConcat)) + } "max" if arg_count > 1 => Ok(Self::Scalar(ScalarFunc::Max)), - "min" if arg_count == 0 || arg_count == 1 => Ok(Self::Agg(AggFunc::Min)), + "max" => { + if arg_count < 1 { + crate::bail_parse_error!("wrong number of arguments to function {}()", name) + } + Ok(Self::Agg(AggFunc::Max)) + } "min" if arg_count > 1 => Ok(Self::Scalar(ScalarFunc::Min)), + "min" => { + if arg_count < 1 { + crate::bail_parse_error!("wrong number of arguments to function {}()", name) + } + Ok(Self::Agg(AggFunc::Min)) + } "nullif" if arg_count == 2 => Ok(Self::Scalar(ScalarFunc::Nullif)), - "string_agg" => Ok(Self::Agg(AggFunc::StringAgg)), - "sum" => Ok(Self::Agg(AggFunc::Sum)), - "total" => Ok(Self::Agg(AggFunc::Total)), + "string_agg" => { + if arg_count != 2 { + crate::bail_parse_error!("wrong number of arguments to function {}()", name) + } + Ok(Self::Agg(AggFunc::StringAgg)) + } + "sum" => { + if arg_count != 1 { + crate::bail_parse_error!("wrong number of arguments to function {}()", name) + } + Ok(Self::Agg(AggFunc::Sum)) + } + "total" => { + if arg_count != 1 { + crate::bail_parse_error!("wrong number of arguments to function {}()", name) + } + Ok(Self::Agg(AggFunc::Total)) + } "char" => Ok(Self::Scalar(ScalarFunc::Char)), "coalesce" => Ok(Self::Scalar(ScalarFunc::Coalesce)), "concat" => Ok(Self::Scalar(ScalarFunc::Concat)), @@ -432,7 +495,7 @@ impl Func { "trunc" => Ok(Self::Math(MathFunc::Trunc)), #[cfg(not(target_family = "wasm"))] "load_extension" => Ok(Self::Scalar(ScalarFunc::LoadExtension)), - _ => Err(()), + _ => crate::bail_parse_error!("no such function: {}", name), } } } diff --git a/core/translate/aggregation.rs b/core/translate/aggregation.rs index c8a7a520a..1456f0fbf 100644 --- a/core/translate/aggregation.rs +++ b/core/translate/aggregation.rs @@ -74,7 +74,7 @@ pub fn translate_aggregation_step( }); target_register } - AggFunc::Count => { + AggFunc::Count | AggFunc::Count0 => { let expr_reg = if agg.args.is_empty() { program.alloc_register() } else { @@ -87,7 +87,11 @@ pub fn translate_aggregation_step( acc_reg: target_register, col: expr_reg, delimiter: 0, - func: AggFunc::Count, + func: if matches!(agg.func, AggFunc::Count0) { + AggFunc::Count0 + } else { + AggFunc::Count + }, }); target_register } diff --git a/core/translate/group_by.rs b/core/translate/group_by.rs index 5855caa06..024c6cfad 100644 --- a/core/translate/group_by.rs +++ b/core/translate/group_by.rs @@ -463,14 +463,18 @@ pub fn translate_aggregation_step_groupby( }); target_register } - AggFunc::Count => { + AggFunc::Count | AggFunc::Count0 => { let expr_reg = program.alloc_register(); emit_column(program, expr_reg); program.emit_insn(Insn::AggStep { acc_reg: target_register, col: expr_reg, delimiter: 0, - func: AggFunc::Count, + func: if matches!(agg.func, AggFunc::Count0) { + AggFunc::Count0 + } else { + AggFunc::Count + }, }); target_register } diff --git a/core/translate/select.rs b/core/translate/select.rs index c37466a9b..ccb3d491f 100644 --- a/core/translate/select.rs +++ b/core/translate/select.rs @@ -116,21 +116,20 @@ pub fn prepare_select_plan( args_count, ) { Ok(Func::Agg(f)) => { - let agg_args: Result, LimboError> = match args { - // if args is None and its COUNT - None if name.0.to_uppercase() == "COUNT" => { - let count_args = vec![ast::Expr::Literal( - ast::Literal::Numeric("1".to_string()), - )]; - Ok(count_args) - } - // if args is None and the function is not COUNT - None => crate::bail_parse_error!( - "Aggregate function {} requires arguments", - name.0 - ), - Some(args) => Ok(args.clone()), - }; + let agg_args: Result, LimboError> = + match (args, &f) { + (None, crate::function::AggFunc::Count0) => { + // COUNT() case + Ok(vec![ast::Expr::Literal( + ast::Literal::Numeric("1".to_string()), + )]) + } + (None, _) => crate::bail_parse_error!( + "Aggregate function {} requires arguments", + name.0 + ), + (Some(args), _) => Ok(args.clone()), + }; let agg = Aggregate { func: f, @@ -163,8 +162,12 @@ pub fn prepare_select_plan( contains_aggregates, }); } - Err(_) => { - if syms.functions.contains_key(&name.0) { + Err(e) => { + // Only handle the "no such function" case specially + // All other errors should be propagated as-is + if e.to_string().starts_with("no such function: ") + && syms.functions.contains_key(&name.0) + { let contains_aggregates = resolve_aggregates( expr, &mut aggregate_expressions, @@ -179,7 +182,10 @@ pub fn prepare_select_plan( expr: expr.clone(), contains_aggregates, }); + continue; // Continue with the normal flow instead of returning } + // Propagate the original error + return Err(e); } } } diff --git a/core/vdbe/mod.rs b/core/vdbe/mod.rs index 43160eb40..992a9db6e 100644 --- a/core/vdbe/mod.rs +++ b/core/vdbe/mod.rs @@ -1202,7 +1202,7 @@ impl Program { // Total() never throws an integer overflow. OwnedValue::Agg(Box::new(AggContext::Sum(OwnedValue::Float(0.0)))) } - AggFunc::Count => { + AggFunc::Count | AggFunc::Count0 => { OwnedValue::Agg(Box::new(AggContext::Count(OwnedValue::Integer(0)))) } AggFunc::Max => { @@ -1270,7 +1270,13 @@ impl Program { }; *acc += col; } - AggFunc::Count => { + AggFunc::Count | AggFunc::Count0 => { + // println!("here"); + if matches!(&state.registers[*acc_reg], OwnedValue::Null) { + state.registers[*acc_reg] = OwnedValue::Agg(Box::new( + AggContext::Count(OwnedValue::Integer(0)), + )); + } let OwnedValue::Agg(agg) = state.registers[*acc_reg].borrow_mut() else { unreachable!(); @@ -1397,7 +1403,7 @@ impl Program { *acc /= count.clone(); } AggFunc::Sum | AggFunc::Total => {} - AggFunc::Count => {} + AggFunc::Count | AggFunc::Count0 => {} AggFunc::Max => {} AggFunc::Min => {} AggFunc::GroupConcat | AggFunc::StringAgg => {} @@ -1409,7 +1415,7 @@ impl Program { AggFunc::Total => { state.registers[*register] = OwnedValue::Float(0.0); } - AggFunc::Count => { + AggFunc::Count | AggFunc::Count0 => { state.registers[*register] = OwnedValue::Integer(0); } _ => {} From 68553904c76f47ac35ac477e0a8a4dfd4f42cb99 Mon Sep 17 00:00:00 2001 From: Krishna Vishal Date: Sat, 18 Jan 2025 11:18:12 +0530 Subject: [PATCH 05/12] Converted the unconditional unwrap to a match which handles the case when the function is COUNT and args are None and replaces the args. Solves https://github.com/tursodatabase/limbo/issues/725 --- core/translate/select.rs | 24 +++++++++++++++++++++--- 1 file changed, 21 insertions(+), 3 deletions(-) diff --git a/core/translate/select.rs b/core/translate/select.rs index 157f55d9d..e75201d43 100644 --- a/core/translate/select.rs +++ b/core/translate/select.rs @@ -1,3 +1,5 @@ +use std::ops::Deref; + use super::emitter::emit_program; use super::expr::get_name; use super::plan::SelectQueryType; @@ -9,10 +11,10 @@ use crate::translate::planner::{ parse_where, resolve_aggregates, OperatorIdCounter, }; use crate::util::normalize_ident; -use crate::SymbolTable; use crate::{schema::Schema, vdbe::builder::ProgramBuilder, Result}; -use sqlite3_parser::ast; +use crate::{LimboError, SymbolTable}; use sqlite3_parser::ast::ResultColumn; +use sqlite3_parser::ast::{self, Expr}; pub fn translate_select( program: &mut ProgramBuilder, @@ -116,9 +118,25 @@ pub fn prepare_select_plan( args_count, ) { Ok(Func::Agg(f)) => { + let count_args = vec![ast::Expr::Literal( + ast::Literal::Numeric("1".to_string()), + )]; + let agg_args: Result, LimboError> = match args { + // if args is None and its COUNT + None if name.0.to_uppercase() == "COUNT" => { + Ok(count_args) + } + // if args is None and the function is not COUNT + None => crate::bail_parse_error!( + "Aggregate function {} requires arguments", + name.0 + ), + Some(args) => Ok(args.clone()), + }; + let agg = Aggregate { func: f, - args: args.as_ref().unwrap().clone(), + args: agg_args?.clone(), original_expr: expr.clone(), }; aggregate_expressions.push(agg.clone()); From 027803aabf667631bd5e0d4eb93cea5679fb4ddc Mon Sep 17 00:00:00 2001 From: Krishna Vishal Date: Sat, 18 Jan 2025 11:31:34 +0530 Subject: [PATCH 06/12] Refactor code --- core/translate/select.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/core/translate/select.rs b/core/translate/select.rs index e75201d43..e5d54b03a 100644 --- a/core/translate/select.rs +++ b/core/translate/select.rs @@ -118,12 +118,12 @@ pub fn prepare_select_plan( args_count, ) { Ok(Func::Agg(f)) => { - let count_args = vec![ast::Expr::Literal( - ast::Literal::Numeric("1".to_string()), - )]; let agg_args: Result, LimboError> = match args { // if args is None and its COUNT None if name.0.to_uppercase() == "COUNT" => { + let count_args = vec![ast::Expr::Literal( + ast::Literal::Numeric("1".to_string()), + )]; Ok(count_args) } // if args is None and the function is not COUNT From ca097b1972332ab054dafd043e9880be78073602 Mon Sep 17 00:00:00 2001 From: Krishna Vishal Date: Sat, 18 Jan 2025 11:35:02 +0530 Subject: [PATCH 07/12] Remove unused import --- core/translate/select.rs | 2 -- 1 file changed, 2 deletions(-) diff --git a/core/translate/select.rs b/core/translate/select.rs index e5d54b03a..26ede7a98 100644 --- a/core/translate/select.rs +++ b/core/translate/select.rs @@ -1,5 +1,3 @@ -use std::ops::Deref; - use super::emitter::emit_program; use super::expr::get_name; use super::plan::SelectQueryType; From 6173aeeb3b100e20f9f3f60e7265592aa7d47ed0 Mon Sep 17 00:00:00 2001 From: Krishna Vishal Date: Sun, 19 Jan 2025 03:00:18 +0530 Subject: [PATCH 08/12] 1. Fix merge conflicts 2. change tests for extensions to return error instead of null (Preston) --- core/function.rs | 86 ++++++++++++++++++++++++++++++----- core/translate/aggregation.rs | 8 +++- core/translate/group_by.rs | 8 +++- core/translate/select.rs | 36 ++++++++------- core/vdbe/mod.rs | 16 +++++-- 5 files changed, 117 insertions(+), 37 deletions(-) diff --git a/core/function.rs b/core/function.rs index 63a1170d9..674ff57be 100644 --- a/core/function.rs +++ b/core/function.rs @@ -3,6 +3,8 @@ use std::fmt; use std::fmt::{Debug, Display}; use std::rc::Rc; +use crate::LimboError; + pub struct ExternalFunc { pub name: String, pub func: ExtFunc, @@ -102,6 +104,7 @@ impl Display for JsonFunc { pub enum AggFunc { Avg, Count, + Count0, GroupConcat, Max, Min, @@ -129,9 +132,25 @@ impl PartialEq for AggFunc { } impl AggFunc { + pub fn num_args(&self) -> usize { + match self { + Self::Avg => 1, + Self::Count0 => 0, + Self::Count => 1, + Self::GroupConcat => 1, + Self::Max => 1, + Self::Min => 1, + Self::StringAgg => 2, + Self::Sum => 1, + Self::Total => 1, + Self::External(func) => func.agg_args().unwrap_or(0), + } + } + pub fn to_string(&self) -> &str { match self { Self::Avg => "avg", + Self::Count0 => "count", Self::Count => "count", Self::GroupConcat => "group_concat", Self::Max => "max", @@ -390,19 +409,64 @@ pub struct FuncCtx { } impl Func { - pub fn resolve_function(name: &str, arg_count: usize) -> Result { + pub fn resolve_function(name: &str, arg_count: usize) -> Result { match name { - "avg" => Ok(Self::Agg(AggFunc::Avg)), - "count" => Ok(Self::Agg(AggFunc::Count)), - "group_concat" => Ok(Self::Agg(AggFunc::GroupConcat)), - "max" if arg_count == 0 || arg_count == 1 => Ok(Self::Agg(AggFunc::Max)), + "avg" => { + if arg_count != 1 { + crate::bail_parse_error!("wrong number of arguments to function {}()", name) + } + Ok(Self::Agg(AggFunc::Avg)) + } + "count" => { + // Handle both COUNT() and COUNT(expr) cases + if arg_count == 0 { + Ok(Self::Agg(AggFunc::Count0)) // COUNT() case + } else if arg_count == 1 { + Ok(Self::Agg(AggFunc::Count)) // COUNT(expr) case + } else { + crate::bail_parse_error!("wrong number of arguments to function {}()", name) + } + } + "group_concat" => { + if arg_count != 1 && arg_count != 2 { + println!("{}", arg_count); + crate::bail_parse_error!("wrong number of arguments to function {}()", name) + } + Ok(Self::Agg(AggFunc::GroupConcat)) + } "max" if arg_count > 1 => Ok(Self::Scalar(ScalarFunc::Max)), - "min" if arg_count == 0 || arg_count == 1 => Ok(Self::Agg(AggFunc::Min)), + "max" => { + if arg_count < 1 { + crate::bail_parse_error!("wrong number of arguments to function {}()", name) + } + Ok(Self::Agg(AggFunc::Max)) + } "min" if arg_count > 1 => Ok(Self::Scalar(ScalarFunc::Min)), + "min" => { + if arg_count < 1 { + crate::bail_parse_error!("wrong number of arguments to function {}()", name) + } + Ok(Self::Agg(AggFunc::Min)) + } "nullif" if arg_count == 2 => Ok(Self::Scalar(ScalarFunc::Nullif)), - "string_agg" => Ok(Self::Agg(AggFunc::StringAgg)), - "sum" => Ok(Self::Agg(AggFunc::Sum)), - "total" => Ok(Self::Agg(AggFunc::Total)), + "string_agg" => { + if arg_count != 2 { + crate::bail_parse_error!("wrong number of arguments to function {}()", name) + } + Ok(Self::Agg(AggFunc::StringAgg)) + } + "sum" => { + if arg_count != 1 { + crate::bail_parse_error!("wrong number of arguments to function {}()", name) + } + Ok(Self::Agg(AggFunc::Sum)) + } + "total" => { + if arg_count != 1 { + crate::bail_parse_error!("wrong number of arguments to function {}()", name) + } + Ok(Self::Agg(AggFunc::Total)) + } "char" => Ok(Self::Scalar(ScalarFunc::Char)), "coalesce" => Ok(Self::Scalar(ScalarFunc::Coalesce)), "concat" => Ok(Self::Scalar(ScalarFunc::Concat)), @@ -486,7 +550,7 @@ impl Func { "trunc" => Ok(Self::Math(MathFunc::Trunc)), #[cfg(not(target_family = "wasm"))] "load_extension" => Ok(Self::Scalar(ScalarFunc::LoadExtension)), - _ => Err(()), + _ => crate::bail_parse_error!("no such function: {}", name), } } -} +} \ No newline at end of file diff --git a/core/translate/aggregation.rs b/core/translate/aggregation.rs index 3fc3c4dac..d23caf3ec 100644 --- a/core/translate/aggregation.rs +++ b/core/translate/aggregation.rs @@ -74,7 +74,7 @@ pub fn translate_aggregation_step( }); target_register } - AggFunc::Count => { + AggFunc::Count | AggFunc::Count0 => { let expr_reg = if agg.args.is_empty() { program.alloc_register() } else { @@ -87,7 +87,11 @@ pub fn translate_aggregation_step( acc_reg: target_register, col: expr_reg, delimiter: 0, - func: AggFunc::Count, + func: if matches!(agg.func, AggFunc::Count0) { + AggFunc::Count0 + } else { + AggFunc::Count + }, }); target_register } diff --git a/core/translate/group_by.rs b/core/translate/group_by.rs index 244c82867..c58d5f56a 100644 --- a/core/translate/group_by.rs +++ b/core/translate/group_by.rs @@ -463,14 +463,18 @@ pub fn translate_aggregation_step_groupby( }); target_register } - AggFunc::Count => { + AggFunc::Count | AggFunc::Count0 => { let expr_reg = program.alloc_register(); emit_column(program, expr_reg); program.emit_insn(Insn::AggStep { acc_reg: target_register, col: expr_reg, delimiter: 0, - func: AggFunc::Count, + func: if matches!(agg.func, AggFunc::Count0) { + AggFunc::Count0 + } else { + AggFunc::Count + }, }); target_register } diff --git a/core/translate/select.rs b/core/translate/select.rs index 26ede7a98..d2c2a8f27 100644 --- a/core/translate/select.rs +++ b/core/translate/select.rs @@ -116,21 +116,20 @@ pub fn prepare_select_plan( args_count, ) { Ok(Func::Agg(f)) => { - let agg_args: Result, LimboError> = match args { - // if args is None and its COUNT - None if name.0.to_uppercase() == "COUNT" => { - let count_args = vec![ast::Expr::Literal( - ast::Literal::Numeric("1".to_string()), - )]; - Ok(count_args) - } - // if args is None and the function is not COUNT - None => crate::bail_parse_error!( - "Aggregate function {} requires arguments", - name.0 - ), - Some(args) => Ok(args.clone()), - }; + let agg_args: Result, LimboError> = + match (args, &f) { + (None, crate::function::AggFunc::Count0) => { + // COUNT() case + Ok(vec![ast::Expr::Literal( + ast::Literal::Numeric("1".to_string()), + )]) + } + (None, _) => crate::bail_parse_error!( + "Aggregate function {} requires arguments", + name.0 + ), + (Some(args), _) => Ok(args.clone()), + }; let agg = Aggregate { func: f, @@ -163,7 +162,7 @@ pub fn prepare_select_plan( contains_aggregates, }); } - Err(_) => { + Err(e) => { if let Some(f) = syms.resolve_function(&name.0, args_count) { if let ExtFunc::Scalar(_) = f.as_ref().func { @@ -199,6 +198,9 @@ pub fn prepare_select_plan( contains_aggregates: true, }); } + continue; // Continue with the normal flow instead of returning + } else { + return Err(e); } } } @@ -333,4 +335,4 @@ pub fn prepare_select_plan( } _ => todo!(), } -} +} \ No newline at end of file diff --git a/core/vdbe/mod.rs b/core/vdbe/mod.rs index 63716d87e..f6940ddd9 100644 --- a/core/vdbe/mod.rs +++ b/core/vdbe/mod.rs @@ -1204,7 +1204,7 @@ impl Program { // Total() never throws an integer overflow. OwnedValue::Agg(Box::new(AggContext::Sum(OwnedValue::Float(0.0)))) } - AggFunc::Count => { + AggFunc::Count | AggFunc::Count0 => { OwnedValue::Agg(Box::new(AggContext::Count(OwnedValue::Integer(0)))) } AggFunc::Max => { @@ -1289,7 +1289,13 @@ impl Program { }; *acc += col; } - AggFunc::Count => { + AggFunc::Count | AggFunc::Count0 => { + // println!("here"); + if matches!(&state.registers[*acc_reg], OwnedValue::Null) { + state.registers[*acc_reg] = OwnedValue::Agg(Box::new( + AggContext::Count(OwnedValue::Integer(0)), + )); + } let OwnedValue::Agg(agg) = state.registers[*acc_reg].borrow_mut() else { unreachable!(); @@ -1437,7 +1443,7 @@ impl Program { *acc /= count.clone(); } AggFunc::Sum | AggFunc::Total => {} - AggFunc::Count => {} + AggFunc::Count | AggFunc::Count0 => {} AggFunc::Max => {} AggFunc::Min => {} AggFunc::GroupConcat | AggFunc::StringAgg => {} @@ -1451,7 +1457,7 @@ impl Program { AggFunc::Total => { state.registers[*register] = OwnedValue::Float(0.0); } - AggFunc::Count => { + AggFunc::Count | AggFunc::Count0 => { state.registers[*register] = OwnedValue::Integer(0); } _ => {} @@ -4209,4 +4215,4 @@ mod tests { expected_str ); } -} +} \ No newline at end of file From fa0503f0ce35b32844da449a3aee8f0458595999 Mon Sep 17 00:00:00 2001 From: Krishna Vishal Date: Sun, 19 Jan 2025 04:58:05 +0530 Subject: [PATCH 09/12] 1. Changes to extension.py 2. chore: cargo fmt --- core/function.rs | 2 +- core/translate/select.rs | 2 +- core/vdbe/mod.rs | 2 +- testing/extensions.py | 16 ++++++++++------ 4 files changed, 13 insertions(+), 9 deletions(-) diff --git a/core/function.rs b/core/function.rs index 674ff57be..79065ab39 100644 --- a/core/function.rs +++ b/core/function.rs @@ -553,4 +553,4 @@ impl Func { _ => crate::bail_parse_error!("no such function: {}", name), } } -} \ No newline at end of file +} diff --git a/core/translate/select.rs b/core/translate/select.rs index d2c2a8f27..8569f8135 100644 --- a/core/translate/select.rs +++ b/core/translate/select.rs @@ -335,4 +335,4 @@ pub fn prepare_select_plan( } _ => todo!(), } -} \ No newline at end of file +} diff --git a/core/vdbe/mod.rs b/core/vdbe/mod.rs index f6940ddd9..291c1b427 100644 --- a/core/vdbe/mod.rs +++ b/core/vdbe/mod.rs @@ -4215,4 +4215,4 @@ mod tests { expected_str ); } -} \ No newline at end of file +} diff --git a/testing/extensions.py b/testing/extensions.py index 61255e804..9dcd27846 100755 --- a/testing/extensions.py +++ b/testing/extensions.py @@ -115,8 +115,12 @@ def validate_string_uuid(result): return len(result) == 36 and result.count("-") == 4 +def returns_error(result): + return "error: no such function: " in result + + def returns_null(result): - return result == "" or result == b"\n" or result == b"" + return result == "" or result == "\n" def assert_now_unixtime(result): @@ -135,10 +139,10 @@ def test_uuid(pipe): run_test( pipe, "SELECT uuid4();", - returns_null, + returns_error, "uuid functions return null when ext not loaded", ) - run_test(pipe, "SELECT uuid4_str();", returns_null) + run_test(pipe, "SELECT uuid4_str();", returns_error) run_test( pipe, f".load {extension_path}", @@ -178,7 +182,7 @@ def test_regexp(pipe): extension_path = "./target/debug/liblimbo_regexp.so" # before extension loads, assert no function - run_test(pipe, "SELECT regexp('a.c', 'abc');", returns_null) + run_test(pipe, "SELECT regexp('a.c', 'abc');", returns_error) run_test(pipe, f".load {extension_path}", returns_null) print(f"Extension {extension_path} loaded successfully.") run_test(pipe, "SELECT regexp('a.c', 'abc');", validate_true) @@ -225,7 +229,7 @@ def test_aggregates(pipe): run_test( pipe, "SELECT median(1);", - returns_null, + returns_error, "median agg function returns null when ext not loaded", ) run_test( @@ -282,4 +286,4 @@ def main(): if __name__ == "__main__": - main() + main() \ No newline at end of file From 870a2ea802155e898361736f611f571eae8500c3 Mon Sep 17 00:00:00 2001 From: Krishna Vishal Date: Sun, 19 Jan 2025 06:40:40 +0530 Subject: [PATCH 10/12] clean up --- core/vdbe/mod.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/core/vdbe/mod.rs b/core/vdbe/mod.rs index 291c1b427..e0798df5d 100644 --- a/core/vdbe/mod.rs +++ b/core/vdbe/mod.rs @@ -1290,7 +1290,6 @@ impl Program { *acc += col; } AggFunc::Count | AggFunc::Count0 => { - // println!("here"); if matches!(&state.registers[*acc_reg], OwnedValue::Null) { state.registers[*acc_reg] = OwnedValue::Agg(Box::new( AggContext::Count(OwnedValue::Integer(0)), From acad562c0710104acd0901ce869cc6f3b054b956 Mon Sep 17 00:00:00 2001 From: Krishna Vishal Date: Sun, 19 Jan 2025 06:50:00 +0530 Subject: [PATCH 11/12] Remove unnecessary Result --- core/translate/select.rs | 29 ++++++++++++++--------------- 1 file changed, 14 insertions(+), 15 deletions(-) diff --git a/core/translate/select.rs b/core/translate/select.rs index 8569f8135..cd8a8bfc5 100644 --- a/core/translate/select.rs +++ b/core/translate/select.rs @@ -116,24 +116,23 @@ pub fn prepare_select_plan( args_count, ) { Ok(Func::Agg(f)) => { - let agg_args: Result, LimboError> = - match (args, &f) { - (None, crate::function::AggFunc::Count0) => { - // COUNT() case - Ok(vec![ast::Expr::Literal( - ast::Literal::Numeric("1".to_string()), - )]) - } - (None, _) => crate::bail_parse_error!( - "Aggregate function {} requires arguments", - name.0 - ), - (Some(args), _) => Ok(args.clone()), - }; + let agg_args = match (args, &f) { + (None, crate::function::AggFunc::Count0) => { + // COUNT() case + vec![ast::Expr::Literal(ast::Literal::Numeric( + "1".to_string(), + ))] + } + (None, _) => crate::bail_parse_error!( + "Aggregate function {} requires arguments", + name.0 + ), + (Some(args), _) => args.clone(), + }; let agg = Aggregate { func: f, - args: agg_args?.clone(), + args: agg_args.clone(), original_expr: expr.clone(), }; aggregate_expressions.push(agg.clone()); From 5cf78b7d54e3a2a810e5d639743830da9a482dbe Mon Sep 17 00:00:00 2001 From: Krishna Vishal Date: Sun, 19 Jan 2025 07:18:31 +0530 Subject: [PATCH 12/12] chore: clippy remove unused imports --- core/translate/select.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/translate/select.rs b/core/translate/select.rs index cd8a8bfc5..35c522494 100644 --- a/core/translate/select.rs +++ b/core/translate/select.rs @@ -10,9 +10,9 @@ use crate::translate::planner::{ }; use crate::util::normalize_ident; use crate::{schema::Schema, vdbe::builder::ProgramBuilder, Result}; -use crate::{LimboError, SymbolTable}; +use crate::SymbolTable; use sqlite3_parser::ast::ResultColumn; -use sqlite3_parser::ast::{self, Expr}; +use sqlite3_parser::ast::{self}; pub fn translate_select( program: &mut ProgramBuilder,