Files
turso/extensions/percentile/src/lib.rs
2025-06-29 12:14:08 +03:00

194 lines
5.7 KiB
Rust

use turso_ext::{register_extension, AggFunc, AggregateDerive, Value};
// Hands the aggregate types defined in this file to `register_extension!`.
// Each name listed under `aggregates` is a unit struct with an `AggFunc` impl below.
register_extension! {
aggregates: { Median, Percentile, PercentileCont, PercentileDisc }
}
/// Aggregate `median(X)`: the middle value of the numeric inputs, or the mean of the two
/// middle values when the count is even.
#[derive(AggregateDerive)]
struct Median;

impl AggFunc for Median {
    /// Every input that converted to a float, in arrival order.
    type State = Vec<f64>;
    type Error = &'static str;
    const NAME: &'static str = "median";
    const ARGS: i32 = 1;

    /// Collects the first argument of one row.
    ///
    /// Rows whose argument is missing or does not convert via `Value::to_float` are skipped.
    fn step(state: &mut Self::State, args: &[Value]) {
        if let Some(val) = args.first().and_then(Value::to_float) {
            state.push(val);
        }
    }

    /// Returns the median of the collected values, or NULL when nothing was collected.
    fn finalize(state: Self::State) -> Result<Value, Self::Error> {
        if state.is_empty() {
            return Ok(Value::null());
        }

        let mut sorted = state;
        // `total_cmp` is a total order over f64, so a NaN among the inputs cannot panic the
        // sort the way `partial_cmp(..).unwrap()` would. Elements that compare equal under it
        // are bit-identical, which makes the unstable sort safe here.
        sorted.sort_unstable_by(f64::total_cmp);

        let len = sorted.len();
        let mid = len / 2;
        if len % 2 == 1 {
            Ok(Value::from_float(sorted[mid]))
        } else {
            // Even count: average the two values that straddle the middle.
            Ok(Value::from_float((sorted[mid - 1] + sorted[mid]) / 2.0))
        }
    }
}
/// Aggregate `percentile(Y, P)`: the P-th percentile of Y, with P on a 0-100 scale and linear
/// interpolation between the two nearest ranks.
#[derive(AggregateDerive)]
struct Percentile;

impl AggFunc for Percentile {
    /// `(accepted Y values, the P shared by every row, first error recorded by step)`.
    type State = (Vec<f64>, Option<f64>, Option<Self::Error>);
    type Error = &'static str;
    const NAME: &'static str = "percentile";
    const ARGS: i32 = 2;

    /// Validates P and collects Y for one row.
    ///
    /// Rows where either argument does not convert to a float are skipped. A P outside
    /// `0.0..=100.0`, or one that differs from the first accepted P by 0.001 or more, records
    /// an error (only the first error is kept) and the row's Y is discarded.
    fn step(state: &mut Self::State, args: &[Value]) {
        let (values, p_value, err_value) = state;
        if let (Some(y), Some(p)) = (
            args.first().and_then(Value::to_float),
            args.get(1).and_then(Value::to_float),
        ) {
            if !(0.0..=100.0).contains(&p) {
                err_value.get_or_insert("Invalid percentile value");
                return;
            }
            if let Some(existing_p) = *p_value {
                if (existing_p - p).abs() >= 0.001 {
                    err_value.get_or_insert("Inconsistent percentile values across rows");
                    return;
                }
            } else {
                *p_value = Some(p);
            }
            values.push(y);
        }
    }

    /// Computes the interpolated percentile of the accepted values.
    ///
    /// Returns NULL when no row was accepted and no error was recorded.
    ///
    /// # Errors
    ///
    /// Returns the first error recorded by `step`, even if no value was ever accepted.
    fn finalize(state: Self::State) -> Result<Value, Self::Error> {
        let (mut values, p_value, err_value) = state;

        // Check the error before the empty case: when every row carries an invalid P nothing
        // is accepted, and the error would otherwise be hidden behind a NULL result.
        if let Some(err) = err_value {
            return Err(err);
        }

        // `step` stores P right before the first push, so a missing P means no accepted rows.
        let p = match p_value {
            Some(p) if !values.is_empty() => p,
            _ => return Ok(Value::null()),
        };
        if values.len() == 1 {
            return Ok(Value::from_float(values[0]));
        }

        // `total_cmp` is a total order over f64, so a NaN input cannot panic the sort.
        values.sort_unstable_by(f64::total_cmp);

        // Fractional rank in [0, n - 1]; P was validated to lie in [0, 100].
        let n = values.len() as f64;
        let index = p * (n - 1.0) / 100.0;
        let lower = index.floor() as usize;
        let upper = index.ceil() as usize;
        if lower == upper {
            Ok(Value::from_float(values[lower]))
        } else {
            // Linear interpolation between the two neighbouring ranks.
            let weight = index - lower as f64;
            Ok(Value::from_float(
                values[lower] * (1.0 - weight) + values[upper] * weight,
            ))
        }
    }
}
/// Aggregate `percentile_cont(Y, P)`: like `percentile`, but P is a fraction in `0.0..=1.0`.
/// The result is linearly interpolated between the two nearest ranks.
#[derive(AggregateDerive)]
struct PercentileCont;

impl AggFunc for PercentileCont {
    /// `(accepted Y values, the P shared by every row, first error recorded by step)`.
    type State = (Vec<f64>, Option<f64>, Option<Self::Error>);
    type Error = &'static str;
    const NAME: &'static str = "percentile_cont";
    const ARGS: i32 = 2;

    /// Validates P and collects Y for one row.
    ///
    /// Rows where either argument does not convert to a float are skipped. A P outside
    /// `0.0..=1.0`, or one that differs from the first accepted P by 0.001 or more, records
    /// an error (only the first error is kept) and the row's Y is discarded.
    fn step(state: &mut Self::State, args: &[Value]) {
        let (values, p_value, err_state) = state;
        if let (Some(y), Some(p)) = (
            args.first().and_then(Value::to_float),
            args.get(1).and_then(Value::to_float),
        ) {
            if !(0.0..=1.0).contains(&p) {
                err_state.get_or_insert("Percentile value must be between 0.0 and 1.0 inclusive");
                return;
            }
            if let Some(existing_p) = *p_value {
                if (existing_p - p).abs() >= 0.001 {
                    err_state.get_or_insert("Inconsistent percentile values across rows");
                    return;
                }
            } else {
                *p_value = Some(p);
            }
            values.push(y);
        }
    }

    /// Computes the interpolated percentile of the accepted values.
    ///
    /// Returns NULL when no row was accepted and no error was recorded.
    ///
    /// # Errors
    ///
    /// Returns the first error recorded by `step`, even if no value was ever accepted.
    fn finalize(state: Self::State) -> Result<Value, Self::Error> {
        let (mut values, p_value, err_state) = state;

        // Check the error before the empty case: when every row carries an invalid P nothing
        // is accepted, and the error would otherwise be hidden behind a NULL result.
        if let Some(err) = err_state {
            return Err(err);
        }

        // `step` stores P right before the first push, so a missing P means no accepted rows.
        let p = match p_value {
            Some(p) if !values.is_empty() => p,
            _ => return Ok(Value::null()),
        };
        if values.len() == 1 {
            return Ok(Value::from_float(values[0]));
        }

        // `total_cmp` is a total order over f64, so a NaN input cannot panic the sort.
        values.sort_unstable_by(f64::total_cmp);

        // Fractional rank in [0, n - 1]; P was validated to lie in [0, 1].
        let n = values.len() as f64;
        let index = p * (n - 1.0);
        let lower = index.floor() as usize;
        let upper = index.ceil() as usize;
        if lower == upper {
            Ok(Value::from_float(values[lower]))
        } else {
            // Linear interpolation between the two neighbouring ranks.
            let weight = index - lower as f64;
            Ok(Value::from_float(
                values[lower] * (1.0 - weight) + values[upper] * weight,
            ))
        }
    }
}
/// Aggregate `percentile_disc(Y, P)`: returns one of the input values rather than an
/// interpolated one. P is a fraction in `0.0..=1.0`; the value at rank `floor(P * (n - 1))`
/// of the sorted inputs is returned.
#[derive(AggregateDerive)]
struct PercentileDisc;

impl AggFunc for PercentileDisc {
    /// `(accepted Y values, the P shared by every row, first error recorded by step)`.
    type State = (Vec<f64>, Option<f64>, Option<Self::Error>);
    type Error = &'static str;
    const NAME: &'static str = "percentile_disc";
    const ARGS: i32 = 2;

    /// Validates P and collects Y for one row.
    ///
    /// P is checked against `0.0..=1.0` because `finalize` uses it as a fraction when
    /// computing the rank. Accepting the 0-100 scale here would let any P above 1 index past
    /// the end of the sorted values.
    ///
    /// Rows where either argument does not convert to a float are skipped. An out-of-range P,
    /// or one that differs from the first accepted P by 0.001 or more, records an error (only
    /// the first error is kept) and the row's Y is discarded.
    fn step(state: &mut Self::State, args: &[Value]) {
        let (values, p_value, err_value) = state;
        if let (Some(y), Some(p)) = (
            args.first().and_then(Value::to_float),
            args.get(1).and_then(Value::to_float),
        ) {
            if !(0.0..=1.0).contains(&p) {
                err_value.get_or_insert("Percentile value must be between 0.0 and 1.0 inclusive");
                return;
            }
            if let Some(existing_p) = *p_value {
                if (existing_p - p).abs() >= 0.001 {
                    err_value.get_or_insert("Inconsistent percentile values across rows");
                    return;
                }
            } else {
                *p_value = Some(p);
            }
            values.push(y);
        }
    }

    /// Picks the accepted value at rank `floor(P * (n - 1))` after sorting.
    ///
    /// Returns NULL when no row was accepted and no error was recorded.
    ///
    /// # Errors
    ///
    /// Returns the first error recorded by `step`, even if no value was ever accepted.
    fn finalize(state: Self::State) -> Result<Value, Self::Error> {
        let (mut values, p_value, err_value) = state;

        // Check the error before the empty case: when every row carries an invalid P nothing
        // is accepted, and the error would otherwise be hidden behind a NULL result.
        if let Some(err) = err_value {
            return Err(err);
        }

        // `step` stores P right before the first push, so a missing P means no accepted rows.
        let p = match p_value {
            Some(p) if !values.is_empty() => p,
            _ => return Ok(Value::null()),
        };

        // `total_cmp` is a total order over f64, so a NaN input cannot panic the sort.
        values.sort_unstable_by(f64::total_cmp);

        // `step` limits P to [0, 1], so the rank is at most `last`; the `min` is a cheap guard
        // that keeps the index in bounds regardless.
        let last = values.len() - 1;
        let index = ((p * last as f64).floor() as usize).min(last);
        Ok(Value::from_float(values[index]))
    }
}