// Mirror of https://github.com/aljazceru/turso.git (synced 2026-01-02 16:04:20 +01:00); 194 lines, 5.7 KiB, Rust.
use turso_ext::{register_extension, AggFunc, AggregateDerive, Value};
|
|
|
|
// Aggregate functions this extension hands to `register_extension!`.
register_extension! {
    aggregates: {
        Median,
        Percentile,
        PercentileCont,
        PercentileDisc
    }
}
|
|
|
|
#[derive(AggregateDerive)]
|
|
struct Median;
|
|
|
|
impl AggFunc for Median {
|
|
type State = Vec<f64>;
|
|
type Error = &'static str;
|
|
const NAME: &'static str = "median";
|
|
const ARGS: i32 = 1;
|
|
|
|
fn step(state: &mut Self::State, args: &[Value]) {
|
|
if let Some(val) = args.first().and_then(Value::to_float) {
|
|
state.push(val);
|
|
}
|
|
}
|
|
|
|
fn finalize(state: Self::State) -> Result<Value, Self::Error> {
|
|
if state.is_empty() {
|
|
return Ok(Value::null());
|
|
}
|
|
|
|
let mut sorted = state;
|
|
sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
|
|
|
|
let len = sorted.len();
|
|
if len % 2 == 1 {
|
|
Ok(Value::from_float(sorted[len / 2]))
|
|
} else {
|
|
let mid1 = sorted[len / 2 - 1];
|
|
let mid2 = sorted[len / 2];
|
|
Ok(Value::from_float((mid1 + mid2) / 2.0))
|
|
}
|
|
}
|
|
}
|
|
|
|
#[derive(AggregateDerive)]
|
|
struct Percentile;
|
|
|
|
impl AggFunc for Percentile {
|
|
type State = (Vec<f64>, Option<f64>, Option<Self::Error>);
|
|
type Error = &'static str;
|
|
const NAME: &'static str = "percentile";
|
|
const ARGS: i32 = 2;
|
|
|
|
fn step(state: &mut Self::State, args: &[Value]) {
|
|
let (values, p_value, err_value) = state;
|
|
if let (Some(y), Some(p)) = (
|
|
args.first().and_then(Value::to_float),
|
|
args.get(1).and_then(Value::to_float),
|
|
) {
|
|
if !(0.0..=100.0).contains(&p) {
|
|
err_value.get_or_insert("Invalid percentile value");
|
|
return;
|
|
}
|
|
|
|
if let Some(existing_p) = *p_value {
|
|
if (existing_p - p).abs() >= 0.001 {
|
|
err_value.get_or_insert("Inconsistent percentile values across rows");
|
|
return;
|
|
}
|
|
} else {
|
|
*p_value = Some(p);
|
|
}
|
|
values.push(y);
|
|
}
|
|
}
|
|
|
|
fn finalize(state: Self::State) -> Result<Value, Self::Error> {
|
|
let (mut values, p_value, err_value) = state;
|
|
if values.is_empty() {
|
|
return Ok(Value::null());
|
|
}
|
|
if let Some(err) = err_value {
|
|
return Err(err);
|
|
}
|
|
if values.len() == 1 {
|
|
return Ok(Value::from_float(values[0]));
|
|
}
|
|
|
|
let p = p_value.unwrap();
|
|
values.sort_by(|a, b| a.partial_cmp(b).unwrap());
|
|
let n = values.len() as f64;
|
|
let index = p * (n - 1.0) / 100.0;
|
|
let lower = index.floor() as usize;
|
|
let upper = index.ceil() as usize;
|
|
|
|
if lower == upper {
|
|
Ok(Value::from_float(values[lower]))
|
|
} else {
|
|
let weight = index - lower as f64;
|
|
Ok(Value::from_float(
|
|
values[lower] * (1.0 - weight) + values[upper] * weight,
|
|
))
|
|
}
|
|
}
|
|
}
|
|
|
|
#[derive(AggregateDerive)]
|
|
struct PercentileCont;
|
|
|
|
impl AggFunc for PercentileCont {
|
|
type State = (Vec<f64>, Option<f64>, Option<Self::Error>);
|
|
type Error = &'static str;
|
|
const NAME: &'static str = "percentile_cont";
|
|
const ARGS: i32 = 2;
|
|
|
|
fn step(state: &mut Self::State, args: &[Value]) {
|
|
let (values, p_value, err_state) = state;
|
|
if let (Some(y), Some(p)) = (
|
|
args.first().and_then(Value::to_float),
|
|
args.get(1).and_then(Value::to_float),
|
|
) {
|
|
if !(0.0..=1.0).contains(&p) {
|
|
err_state.get_or_insert("Percentile value must be between 0.0 and 1.0 inclusive");
|
|
return;
|
|
}
|
|
|
|
if let Some(existing_p) = *p_value {
|
|
if (existing_p - p).abs() >= 0.001 {
|
|
err_state.get_or_insert("Inconsistent percentile values across rows");
|
|
return;
|
|
}
|
|
} else {
|
|
*p_value = Some(p);
|
|
}
|
|
values.push(y);
|
|
}
|
|
}
|
|
|
|
fn finalize(state: Self::State) -> Result<Value, Self::Error> {
|
|
let (mut values, p_value, err_state) = state;
|
|
if values.is_empty() {
|
|
return Ok(Value::null());
|
|
}
|
|
if let Some(err) = err_state {
|
|
return Err(err);
|
|
}
|
|
if values.len() == 1 {
|
|
return Ok(Value::from_float(values[0]));
|
|
}
|
|
|
|
let p = p_value.unwrap();
|
|
values.sort_by(|a, b| a.partial_cmp(b).unwrap());
|
|
let n = values.len() as f64;
|
|
let index = p * (n - 1.0);
|
|
let lower = index.floor() as usize;
|
|
let upper = index.ceil() as usize;
|
|
|
|
if lower == upper {
|
|
Ok(Value::from_float(values[lower]))
|
|
} else {
|
|
let weight = index - lower as f64;
|
|
Ok(Value::from_float(
|
|
values[lower] * (1.0 - weight) + values[upper] * weight,
|
|
))
|
|
}
|
|
}
|
|
}
|
|
|
|
#[derive(AggregateDerive)]
|
|
struct PercentileDisc;
|
|
|
|
impl AggFunc for PercentileDisc {
|
|
type State = (Vec<f64>, Option<f64>, Option<Self::Error>);
|
|
type Error = &'static str;
|
|
const NAME: &'static str = "percentile_disc";
|
|
const ARGS: i32 = 2;
|
|
|
|
fn step(state: &mut Self::State, args: &[Value]) {
|
|
Percentile::step(state, args);
|
|
}
|
|
|
|
fn finalize(state: Self::State) -> Result<Value, Self::Error> {
|
|
let (mut values, p_value, err_value) = state;
|
|
if values.is_empty() {
|
|
return Ok(Value::null());
|
|
}
|
|
if let Some(err) = err_value {
|
|
return Err(err);
|
|
}
|
|
|
|
let p = p_value.unwrap();
|
|
values.sort_by(|a, b| a.partial_cmp(b).unwrap());
|
|
let n = values.len() as f64;
|
|
let index = (p * (n - 1.0)).floor() as usize;
|
|
Ok(Value::from_float(values[index]))
|
|
}
|
|
}
|