diff --git a/Cargo.lock b/Cargo.lock index 6284a9e8c..4333604a5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1274,6 +1274,13 @@ dependencies = [ "syn 2.0.96", ] +[[package]] +name = "limbo_median" +version = "0.0.12" +dependencies = [ + "limbo_ext", +] + [[package]] name = "limbo_regexp" version = "0.0.12" diff --git a/Cargo.toml b/Cargo.toml index 690d8feec..83db6dfa2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,7 +15,7 @@ members = [ "macros", "simulator", "sqlite3", - "test", + "test", "extensions/median", ] exclude = ["perf/latency/limbo"] diff --git a/extensions/median/Cargo.toml b/extensions/median/Cargo.toml new file mode 100644 index 000000000..615112628 --- /dev/null +++ b/extensions/median/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "limbo_median" +version.workspace = true +authors.workspace = true +edition.workspace = true +license.workspace = true +repository.workspace = true + +[lib] +crate-type = ["cdylib"] + +[dependencies] +limbo_ext = { path = "../core" } diff --git a/extensions/median/src/lib.rs b/extensions/median/src/lib.rs new file mode 100644 index 000000000..4b123369d --- /dev/null +++ b/extensions/median/src/lib.rs @@ -0,0 +1,44 @@ +use limbo_ext::{register_extension, AggFunc, AggregateDerive, Value}; + +register_extension! { + aggregates: { MedianState }, +} + +#[derive(AggregateDerive)] +struct MedianState; + +impl AggFunc for MedianState { + type State = Vec; + + fn name(&self) -> &'static str { + "median" + } + + fn args(&self) -> i32 { + 1 + } + + fn step(state: &mut Self::State, args: &[Value]) { + if let Some(val) = args.first().and_then(Value::to_float) { + state.push(val); + } + } + + fn finalize(state: Self::State) -> Value { + if state.is_empty() { + return Value::null(); + } + + let mut sorted = state; + sorted.sort_by(|a, b| a.partial_cmp(b).unwrap()); + + let len = sorted.len(); + if len % 2 == 1 { + Value::from_float(sorted[len / 2]) + } else { + let mid1 = sorted[len / 2 - 1]; + let mid2 = sorted[len / 2]; + Value::from_float((mid1 + mid2) / 2.0) + } + } +} diff --git a/testing/extensions.py b/testing/extensions.py index 55c380f5b..ad87c65ca 100755 --- a/testing/extensions.py +++ b/testing/extensions.py @@ -7,6 +7,16 @@ import time sqlite_exec = "./target/debug/limbo" sqlite_flags = os.getenv("SQLITE_FLAGS", "-q").split(" ") +test_data = """CREATE TABLE numbers ( id INTEGER PRIMARY KEY, value FLOAT NOT NULL); +INSERT INTO numbers (value) VALUES (1.0); +INSERT INTO numbers (value) VALUES (2.0); +INSERT INTO numbers (value) VALUES (3.0); +INSERT INTO numbers (value) VALUES (4.0); +INSERT INTO numbers (value) VALUES (5.0); +INSERT INTO numbers (value) VALUES (6.0); +INSERT INTO numbers (value) VALUES (7.0); +""" + def init_limbo(): pipe = subprocess.Popen( @@ -16,6 +26,7 @@ def init_limbo(): stderr=subprocess.PIPE, bufsize=0, ) + write_to_pipe(pipe, test_data) return pipe @@ -180,11 +191,39 @@ def test_regexp(pipe): ) +def validate_median(res): + return res == "4.0" + + +def test_aggregates(pipe): + extension_path = "./target/debug/liblimbo_median.so" + # assert no function before extension loads + run_test( + pipe, + "SELECT median(1);", + returns_null, + "median agg function returns null when ext not loaded", + ) + run_test( + pipe, + f".load {extension_path}", + returns_null, + "load extension command works properly", + ) + run_test( + pipe, + "select median(value) from numbers;", + validate_median, + "median agg function works", + ) + + def main(): pipe = init_limbo() try: test_regexp(pipe) test_uuid(pipe) + test_aggregates(pipe) except Exception as e: print(f"Test FAILED: {e}") pipe.terminate()