mirror of
https://github.com/aljazceru/turso.git
synced 2025-12-25 12:04:21 +01:00
Merge 'Fix panics on invalid aggregate function arguments' from Krishna Vishal
This pull request fixes panics when there's an aggregate function call with too few or too many arguments. Closes #726
This commit is contained in:
@@ -3,6 +3,8 @@ use std::fmt;
|
||||
use std::fmt::{Debug, Display};
|
||||
use std::rc::Rc;
|
||||
|
||||
use crate::LimboError;
|
||||
|
||||
pub struct ExternalFunc {
|
||||
pub name: String,
|
||||
pub func: ExtFunc,
|
||||
@@ -102,6 +104,7 @@ impl Display for JsonFunc {
|
||||
pub enum AggFunc {
|
||||
Avg,
|
||||
Count,
|
||||
Count0,
|
||||
GroupConcat,
|
||||
Max,
|
||||
Min,
|
||||
@@ -129,9 +132,25 @@ impl PartialEq for AggFunc {
|
||||
}
|
||||
|
||||
impl AggFunc {
|
||||
pub fn num_args(&self) -> usize {
|
||||
match self {
|
||||
Self::Avg => 1,
|
||||
Self::Count0 => 0,
|
||||
Self::Count => 1,
|
||||
Self::GroupConcat => 1,
|
||||
Self::Max => 1,
|
||||
Self::Min => 1,
|
||||
Self::StringAgg => 2,
|
||||
Self::Sum => 1,
|
||||
Self::Total => 1,
|
||||
Self::External(func) => func.agg_args().unwrap_or(0),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn to_string(&self) -> &str {
|
||||
match self {
|
||||
Self::Avg => "avg",
|
||||
Self::Count0 => "count",
|
||||
Self::Count => "count",
|
||||
Self::GroupConcat => "group_concat",
|
||||
Self::Max => "max",
|
||||
@@ -390,19 +409,64 @@ pub struct FuncCtx {
|
||||
}
|
||||
|
||||
impl Func {
|
||||
pub fn resolve_function(name: &str, arg_count: usize) -> Result<Self, ()> {
|
||||
pub fn resolve_function(name: &str, arg_count: usize) -> Result<Self, LimboError> {
|
||||
match name {
|
||||
"avg" => Ok(Self::Agg(AggFunc::Avg)),
|
||||
"count" => Ok(Self::Agg(AggFunc::Count)),
|
||||
"group_concat" => Ok(Self::Agg(AggFunc::GroupConcat)),
|
||||
"max" if arg_count == 0 || arg_count == 1 => Ok(Self::Agg(AggFunc::Max)),
|
||||
"avg" => {
|
||||
if arg_count != 1 {
|
||||
crate::bail_parse_error!("wrong number of arguments to function {}()", name)
|
||||
}
|
||||
Ok(Self::Agg(AggFunc::Avg))
|
||||
}
|
||||
"count" => {
|
||||
// Handle both COUNT() and COUNT(expr) cases
|
||||
if arg_count == 0 {
|
||||
Ok(Self::Agg(AggFunc::Count0)) // COUNT() case
|
||||
} else if arg_count == 1 {
|
||||
Ok(Self::Agg(AggFunc::Count)) // COUNT(expr) case
|
||||
} else {
|
||||
crate::bail_parse_error!("wrong number of arguments to function {}()", name)
|
||||
}
|
||||
}
|
||||
"group_concat" => {
|
||||
if arg_count != 1 && arg_count != 2 {
|
||||
println!("{}", arg_count);
|
||||
crate::bail_parse_error!("wrong number of arguments to function {}()", name)
|
||||
}
|
||||
Ok(Self::Agg(AggFunc::GroupConcat))
|
||||
}
|
||||
"max" if arg_count > 1 => Ok(Self::Scalar(ScalarFunc::Max)),
|
||||
"min" if arg_count == 0 || arg_count == 1 => Ok(Self::Agg(AggFunc::Min)),
|
||||
"max" => {
|
||||
if arg_count < 1 {
|
||||
crate::bail_parse_error!("wrong number of arguments to function {}()", name)
|
||||
}
|
||||
Ok(Self::Agg(AggFunc::Max))
|
||||
}
|
||||
"min" if arg_count > 1 => Ok(Self::Scalar(ScalarFunc::Min)),
|
||||
"min" => {
|
||||
if arg_count < 1 {
|
||||
crate::bail_parse_error!("wrong number of arguments to function {}()", name)
|
||||
}
|
||||
Ok(Self::Agg(AggFunc::Min))
|
||||
}
|
||||
"nullif" if arg_count == 2 => Ok(Self::Scalar(ScalarFunc::Nullif)),
|
||||
"string_agg" => Ok(Self::Agg(AggFunc::StringAgg)),
|
||||
"sum" => Ok(Self::Agg(AggFunc::Sum)),
|
||||
"total" => Ok(Self::Agg(AggFunc::Total)),
|
||||
"string_agg" => {
|
||||
if arg_count != 2 {
|
||||
crate::bail_parse_error!("wrong number of arguments to function {}()", name)
|
||||
}
|
||||
Ok(Self::Agg(AggFunc::StringAgg))
|
||||
}
|
||||
"sum" => {
|
||||
if arg_count != 1 {
|
||||
crate::bail_parse_error!("wrong number of arguments to function {}()", name)
|
||||
}
|
||||
Ok(Self::Agg(AggFunc::Sum))
|
||||
}
|
||||
"total" => {
|
||||
if arg_count != 1 {
|
||||
crate::bail_parse_error!("wrong number of arguments to function {}()", name)
|
||||
}
|
||||
Ok(Self::Agg(AggFunc::Total))
|
||||
}
|
||||
"char" => Ok(Self::Scalar(ScalarFunc::Char)),
|
||||
"coalesce" => Ok(Self::Scalar(ScalarFunc::Coalesce)),
|
||||
"concat" => Ok(Self::Scalar(ScalarFunc::Concat)),
|
||||
@@ -486,7 +550,7 @@ impl Func {
|
||||
"trunc" => Ok(Self::Math(MathFunc::Trunc)),
|
||||
#[cfg(not(target_family = "wasm"))]
|
||||
"load_extension" => Ok(Self::Scalar(ScalarFunc::LoadExtension)),
|
||||
_ => Err(()),
|
||||
_ => crate::bail_parse_error!("no such function: {}", name),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -74,7 +74,7 @@ pub fn translate_aggregation_step(
|
||||
});
|
||||
target_register
|
||||
}
|
||||
AggFunc::Count => {
|
||||
AggFunc::Count | AggFunc::Count0 => {
|
||||
let expr_reg = if agg.args.is_empty() {
|
||||
program.alloc_register()
|
||||
} else {
|
||||
@@ -87,7 +87,11 @@ pub fn translate_aggregation_step(
|
||||
acc_reg: target_register,
|
||||
col: expr_reg,
|
||||
delimiter: 0,
|
||||
func: AggFunc::Count,
|
||||
func: if matches!(agg.func, AggFunc::Count0) {
|
||||
AggFunc::Count0
|
||||
} else {
|
||||
AggFunc::Count
|
||||
},
|
||||
});
|
||||
target_register
|
||||
}
|
||||
|
||||
@@ -463,14 +463,18 @@ pub fn translate_aggregation_step_groupby(
|
||||
});
|
||||
target_register
|
||||
}
|
||||
AggFunc::Count => {
|
||||
AggFunc::Count | AggFunc::Count0 => {
|
||||
let expr_reg = program.alloc_register();
|
||||
emit_column(program, expr_reg);
|
||||
program.emit_insn(Insn::AggStep {
|
||||
acc_reg: target_register,
|
||||
col: expr_reg,
|
||||
delimiter: 0,
|
||||
func: AggFunc::Count,
|
||||
func: if matches!(agg.func, AggFunc::Count0) {
|
||||
AggFunc::Count0
|
||||
} else {
|
||||
AggFunc::Count
|
||||
},
|
||||
});
|
||||
target_register
|
||||
}
|
||||
|
||||
@@ -9,10 +9,10 @@ use crate::translate::planner::{
|
||||
parse_where, resolve_aggregates, OperatorIdCounter,
|
||||
};
|
||||
use crate::util::normalize_ident;
|
||||
use crate::SymbolTable;
|
||||
use crate::{schema::Schema, vdbe::builder::ProgramBuilder, Result};
|
||||
use sqlite3_parser::ast;
|
||||
use crate::SymbolTable;
|
||||
use sqlite3_parser::ast::ResultColumn;
|
||||
use sqlite3_parser::ast::{self};
|
||||
|
||||
pub fn translate_select(
|
||||
program: &mut ProgramBuilder,
|
||||
@@ -116,9 +116,23 @@ pub fn prepare_select_plan(
|
||||
args_count,
|
||||
) {
|
||||
Ok(Func::Agg(f)) => {
|
||||
let agg_args = match (args, &f) {
|
||||
(None, crate::function::AggFunc::Count0) => {
|
||||
// COUNT() case
|
||||
vec![ast::Expr::Literal(ast::Literal::Numeric(
|
||||
"1".to_string(),
|
||||
))]
|
||||
}
|
||||
(None, _) => crate::bail_parse_error!(
|
||||
"Aggregate function {} requires arguments",
|
||||
name.0
|
||||
),
|
||||
(Some(args), _) => args.clone(),
|
||||
};
|
||||
|
||||
let agg = Aggregate {
|
||||
func: f,
|
||||
args: args.as_ref().unwrap().clone(),
|
||||
args: agg_args.clone(),
|
||||
original_expr: expr.clone(),
|
||||
};
|
||||
aggregate_expressions.push(agg.clone());
|
||||
@@ -147,7 +161,7 @@ pub fn prepare_select_plan(
|
||||
contains_aggregates,
|
||||
});
|
||||
}
|
||||
Err(_) => {
|
||||
Err(e) => {
|
||||
if let Some(f) = syms.resolve_function(&name.0, args_count)
|
||||
{
|
||||
if let ExtFunc::Scalar(_) = f.as_ref().func {
|
||||
@@ -183,6 +197,9 @@ pub fn prepare_select_plan(
|
||||
contains_aggregates: true,
|
||||
});
|
||||
}
|
||||
continue; // Continue with the normal flow instead of returning
|
||||
} else {
|
||||
return Err(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1204,7 +1204,7 @@ impl Program {
|
||||
// Total() never throws an integer overflow.
|
||||
OwnedValue::Agg(Box::new(AggContext::Sum(OwnedValue::Float(0.0))))
|
||||
}
|
||||
AggFunc::Count => {
|
||||
AggFunc::Count | AggFunc::Count0 => {
|
||||
OwnedValue::Agg(Box::new(AggContext::Count(OwnedValue::Integer(0))))
|
||||
}
|
||||
AggFunc::Max => {
|
||||
@@ -1289,7 +1289,12 @@ impl Program {
|
||||
};
|
||||
*acc += col;
|
||||
}
|
||||
AggFunc::Count => {
|
||||
AggFunc::Count | AggFunc::Count0 => {
|
||||
if matches!(&state.registers[*acc_reg], OwnedValue::Null) {
|
||||
state.registers[*acc_reg] = OwnedValue::Agg(Box::new(
|
||||
AggContext::Count(OwnedValue::Integer(0)),
|
||||
));
|
||||
}
|
||||
let OwnedValue::Agg(agg) = state.registers[*acc_reg].borrow_mut()
|
||||
else {
|
||||
unreachable!();
|
||||
@@ -1437,7 +1442,7 @@ impl Program {
|
||||
*acc /= count.clone();
|
||||
}
|
||||
AggFunc::Sum | AggFunc::Total => {}
|
||||
AggFunc::Count => {}
|
||||
AggFunc::Count | AggFunc::Count0 => {}
|
||||
AggFunc::Max => {}
|
||||
AggFunc::Min => {}
|
||||
AggFunc::GroupConcat | AggFunc::StringAgg => {}
|
||||
@@ -1451,7 +1456,7 @@ impl Program {
|
||||
AggFunc::Total => {
|
||||
state.registers[*register] = OwnedValue::Float(0.0);
|
||||
}
|
||||
AggFunc::Count => {
|
||||
AggFunc::Count | AggFunc::Count0 => {
|
||||
state.registers[*register] = OwnedValue::Integer(0);
|
||||
}
|
||||
_ => {}
|
||||
|
||||
@@ -115,8 +115,12 @@ def validate_string_uuid(result):
|
||||
return len(result) == 36 and result.count("-") == 4
|
||||
|
||||
|
||||
def returns_error(result):
|
||||
return "error: no such function: " in result
|
||||
|
||||
|
||||
def returns_null(result):
|
||||
return result == "" or result == b"\n" or result == b""
|
||||
return result == "" or result == "\n"
|
||||
|
||||
|
||||
def assert_now_unixtime(result):
|
||||
@@ -135,10 +139,10 @@ def test_uuid(pipe):
|
||||
run_test(
|
||||
pipe,
|
||||
"SELECT uuid4();",
|
||||
returns_null,
|
||||
returns_error,
|
||||
"uuid functions return null when ext not loaded",
|
||||
)
|
||||
run_test(pipe, "SELECT uuid4_str();", returns_null)
|
||||
run_test(pipe, "SELECT uuid4_str();", returns_error)
|
||||
run_test(
|
||||
pipe,
|
||||
f".load {extension_path}",
|
||||
@@ -178,7 +182,7 @@ def test_regexp(pipe):
|
||||
extension_path = "./target/debug/liblimbo_regexp.so"
|
||||
|
||||
# before extension loads, assert no function
|
||||
run_test(pipe, "SELECT regexp('a.c', 'abc');", returns_null)
|
||||
run_test(pipe, "SELECT regexp('a.c', 'abc');", returns_error)
|
||||
run_test(pipe, f".load {extension_path}", returns_null)
|
||||
print(f"Extension {extension_path} loaded successfully.")
|
||||
run_test(pipe, "SELECT regexp('a.c', 'abc');", validate_true)
|
||||
@@ -225,7 +229,7 @@ def test_aggregates(pipe):
|
||||
run_test(
|
||||
pipe,
|
||||
"SELECT median(1);",
|
||||
returns_null,
|
||||
returns_error,
|
||||
"median agg function returns null when ext not loaded",
|
||||
)
|
||||
run_test(
|
||||
@@ -282,4 +286,4 @@ def main():
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
main()
|
||||
Reference in New Issue
Block a user