mirror of
https://github.com/aljazceru/turso.git
synced 2025-12-18 17:14:20 +01:00
Support external aggregate functions in GROUP BY
This commit is contained in:
@@ -12,7 +12,7 @@ use crate::{
|
||||
insn::Insn,
|
||||
BranchOffset,
|
||||
},
|
||||
Result,
|
||||
LimboError, Result,
|
||||
};
|
||||
use crate::translate::aggregation::emit_collseq_if_needed;
|
||||
use super::{
|
||||
@@ -1119,8 +1119,34 @@ pub fn translate_aggregation_step_groupby(
|
||||
});
|
||||
target_register
|
||||
}
|
||||
AggFunc::External(_) => {
|
||||
todo!("External aggregate functions are not yet supported in GROUP BY");
|
||||
AggFunc::External(ref func) => {
|
||||
let argc = func.agg_args().map_err(|_| {
|
||||
LimboError::ExtensionError(
|
||||
"External aggregate function called with wrong number of arguments".to_string(),
|
||||
)
|
||||
})?;
|
||||
if argc != num_args {
|
||||
crate::bail_parse_error!(
|
||||
"External aggregate function called with wrong number of arguments"
|
||||
);
|
||||
}
|
||||
let expr_reg = agg_arg_source.translate(program, 0)?;
|
||||
for i in 0..argc {
|
||||
if i != 0 {
|
||||
let _ = agg_arg_source.translate(program, i)?;
|
||||
}
|
||||
// invariant: distinct aggregates are only supported for single-argument functions
|
||||
if argc == 1 {
|
||||
handle_distinct(program, agg_arg_source.aggregate(), expr_reg + i);
|
||||
}
|
||||
}
|
||||
program.emit_insn(Insn::AggStep {
|
||||
acc_reg: target_register,
|
||||
col: expr_reg,
|
||||
delimiter: 0,
|
||||
func: AggFunc::External(func.clone()),
|
||||
});
|
||||
target_register
|
||||
}
|
||||
};
|
||||
Ok(dest)
|
||||
|
||||
@@ -7,22 +7,22 @@ from cli_tests.test_turso_cli import TestTursoShell
|
||||
sqlite_exec = "./scripts/limbo-sqlite3"
|
||||
sqlite_flags = os.getenv("SQLITE_FLAGS", "-q").split(" ")
|
||||
|
||||
test_data = """CREATE TABLE numbers ( id INTEGER PRIMARY KEY, value FLOAT NOT NULL);
|
||||
INSERT INTO numbers (value) VALUES (1.0);
|
||||
INSERT INTO numbers (value) VALUES (2.0);
|
||||
INSERT INTO numbers (value) VALUES (3.0);
|
||||
INSERT INTO numbers (value) VALUES (4.0);
|
||||
INSERT INTO numbers (value) VALUES (5.0);
|
||||
INSERT INTO numbers (value) VALUES (6.0);
|
||||
INSERT INTO numbers (value) VALUES (7.0);
|
||||
CREATE TABLE test (value REAL, percent REAL);
|
||||
INSERT INTO test values (10, 25);
|
||||
INSERT INTO test values (20, 25);
|
||||
INSERT INTO test values (30, 25);
|
||||
INSERT INTO test values (40, 25);
|
||||
INSERT INTO test values (50, 25);
|
||||
INSERT INTO test values (60, 25);
|
||||
INSERT INTO test values (70, 25);
|
||||
test_data = """CREATE TABLE numbers ( id INTEGER PRIMARY KEY, value FLOAT NOT NULL, category TEXT DEFAULT 'A');
|
||||
INSERT INTO numbers (value, category) VALUES (1.0, 'A');
|
||||
INSERT INTO numbers (value, category) VALUES (2.0, 'A');
|
||||
INSERT INTO numbers (value, category) VALUES (3.0, 'A');
|
||||
INSERT INTO numbers (value, category) VALUES (4.0, 'B');
|
||||
INSERT INTO numbers (value, category) VALUES (5.0, 'B');
|
||||
INSERT INTO numbers (value, category) VALUES (6.0, 'B');
|
||||
INSERT INTO numbers (value, category) VALUES (7.0, 'B');
|
||||
CREATE TABLE test (value REAL, percent REAL, category TEXT);
|
||||
INSERT INTO test values (10, 25, 'A');
|
||||
INSERT INTO test values (20, 25, 'A');
|
||||
INSERT INTO test values (30, 25, 'B');
|
||||
INSERT INTO test values (40, 25, 'C');
|
||||
INSERT INTO test values (50, 25, 'C');
|
||||
INSERT INTO test values (60, 25, 'C');
|
||||
INSERT INTO test values (70, 25, 'D');
|
||||
"""
|
||||
|
||||
|
||||
@@ -174,6 +174,39 @@ def test_aggregates():
|
||||
limbo.quit()
|
||||
|
||||
|
||||
def test_grouped_aggregates():
|
||||
limbo = TestTursoShell(init_commands=test_data)
|
||||
extension_path = "./target/debug/liblimbo_percentile"
|
||||
limbo.execute_dot(f".load {extension_path}")
|
||||
|
||||
limbo.run_test_fn(
|
||||
"SELECT median(value) FROM numbers GROUP BY category;",
|
||||
lambda res: "2.0\n5.5" == res,
|
||||
"median aggregate function works",
|
||||
)
|
||||
limbo.run_test_fn(
|
||||
"SELECT percentile(value, percent) FROM test GROUP BY category;",
|
||||
lambda res: "12.5\n30.0\n45.0\n70.0" == res,
|
||||
"grouped aggregate percentile function with 2 arguments works",
|
||||
)
|
||||
limbo.run_test_fn(
|
||||
"SELECT percentile(value, 55) FROM test GROUP BY category;",
|
||||
lambda res: "15.5\n30.0\n51.0\n70.0" == res,
|
||||
"grouped aggregate percentile function with 1 argument works",
|
||||
)
|
||||
limbo.run_test_fn(
|
||||
"SELECT percentile_cont(value, 0.25) FROM test GROUP BY category;",
|
||||
lambda res: "12.5\n30.0\n45.0\n70.0" == res,
|
||||
"grouped aggregate percentile_cont function works",
|
||||
)
|
||||
limbo.run_test_fn(
|
||||
"SELECT percentile_disc(value, 0.55) FROM test GROUP BY category;",
|
||||
lambda res: "10.0\n30.0\n50.0\n70.0" == res,
|
||||
"grouped aggregate percentile_disc function works",
|
||||
)
|
||||
limbo.quit()
|
||||
|
||||
|
||||
# Encoders and decoders
|
||||
def validate_url_encode(a):
|
||||
return a == "%2Fhello%3Ftext%3D%28%E0%B2%A0_%E0%B2%A0%29"
|
||||
@@ -770,6 +803,7 @@ def main():
|
||||
test_regexp()
|
||||
test_uuid()
|
||||
test_aggregates()
|
||||
test_grouped_aggregates()
|
||||
test_crypto()
|
||||
test_series()
|
||||
test_ipaddr()
|
||||
|
||||
Reference in New Issue
Block a user