mirror of
https://github.com/aljazceru/turso.git
synced 2026-01-26 11:24:32 +01:00
Merge pull request #197 from brayanjuls/min_max_scalar_func
This commit is contained in:
@@ -1,4 +1,3 @@
|
||||
use std::str::FromStr;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub enum AggFunc {
|
||||
@@ -38,6 +37,8 @@ pub enum SingleRowFunc {
|
||||
Trim,
|
||||
Round,
|
||||
Length,
|
||||
Min,
|
||||
Max,
|
||||
}
|
||||
|
||||
impl ToString for SingleRowFunc {
|
||||
@@ -52,6 +53,8 @@ impl ToString for SingleRowFunc {
|
||||
SingleRowFunc::Trim => "trim".to_string(),
|
||||
SingleRowFunc::Round => "round".to_string(),
|
||||
SingleRowFunc::Length => "length".to_string(),
|
||||
SingleRowFunc::Min => "min".to_string(),
|
||||
SingleRowFunc::Max => "max".to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -62,16 +65,17 @@ pub enum Func {
|
||||
SingleRow(SingleRowFunc),
|
||||
}
|
||||
|
||||
impl FromStr for Func {
|
||||
type Err = ();
|
||||
impl Func{
|
||||
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
match s {
|
||||
pub fn resolve_function(name: &str, arg_count:usize) -> Result<Func, ()>{
|
||||
match name {
|
||||
"avg" => Ok(Func::Agg(AggFunc::Avg)),
|
||||
"count" => Ok(Func::Agg(AggFunc::Count)),
|
||||
"group_concat" => Ok(Func::Agg(AggFunc::GroupConcat)),
|
||||
"max" => Ok(Func::Agg(AggFunc::Max)),
|
||||
"min" => Ok(Func::Agg(AggFunc::Min)),
|
||||
"max" if arg_count == 0 || arg_count == 1 => Ok(Func::Agg(AggFunc::Max)),
|
||||
"max" if arg_count > 1 => Ok(Func::SingleRow(SingleRowFunc::Max)),
|
||||
"min" if arg_count == 0 || arg_count == 1 => Ok(Func::Agg(AggFunc::Min)),
|
||||
"min" if arg_count > 1 => Ok(Func::SingleRow(SingleRowFunc::Min)),
|
||||
"string_agg" => Ok(Func::Agg(AggFunc::StringAgg)),
|
||||
"sum" => Ok(Func::Agg(AggFunc::Sum)),
|
||||
"total" => Ok(Func::Agg(AggFunc::Total)),
|
||||
@@ -87,4 +91,4 @@ impl FromStr for Func {
|
||||
_ => Err(()),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -225,7 +225,8 @@ pub fn translate_expr(
|
||||
args,
|
||||
filter_over: _,
|
||||
} => {
|
||||
let func_type: Option<Func> = normalize_ident(name.0.as_str()).as_str().parse().ok();
|
||||
let args_count = if let Some(args) = args { args.len() } else { 0 };
|
||||
let func_type: Option<Func> = Func::resolve_function(normalize_ident(name.0.as_str()).as_str(),args_count).ok();
|
||||
|
||||
match func_type {
|
||||
Some(Func::Agg(_)) => {
|
||||
@@ -377,6 +378,60 @@ pub fn translate_expr(
|
||||
});
|
||||
Ok(target_register)
|
||||
}
|
||||
SingleRowFunc::Min => {
|
||||
let args = if let Some(args) = args {
|
||||
if args.len() < 1 {
|
||||
anyhow::bail!(
|
||||
"Parse error: min function with less than one argument"
|
||||
);
|
||||
}
|
||||
args
|
||||
} else {
|
||||
anyhow::bail!("Parse error: min function with no arguments");
|
||||
};
|
||||
for arg in args {
|
||||
let reg = program.alloc_register();
|
||||
let _ = translate_expr(program, select, arg, reg, cursor_hint)?;
|
||||
match arg {
|
||||
ast::Expr::Literal(_) => program.mark_last_insn_constant(),
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
program.emit_insn(Insn::Function {
|
||||
start_reg: target_register + 1,
|
||||
dest: target_register,
|
||||
func: SingleRowFunc::Min,
|
||||
});
|
||||
Ok(target_register)
|
||||
}
|
||||
SingleRowFunc::Max => {
|
||||
let args = if let Some(args) = args {
|
||||
if args.len() < 1 {
|
||||
anyhow::bail!(
|
||||
"Parse error: max function with less than one argument"
|
||||
);
|
||||
}
|
||||
args
|
||||
} else {
|
||||
anyhow::bail!("Parse error: max function with no arguments");
|
||||
};
|
||||
for arg in args {
|
||||
let reg = program.alloc_register();
|
||||
let _ = translate_expr(program, select, arg, reg, cursor_hint)?;
|
||||
match arg {
|
||||
ast::Expr::Literal(_) => program.mark_last_insn_constant(),
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
program.emit_insn(Insn::Function {
|
||||
start_reg: target_register + 1,
|
||||
dest: target_register,
|
||||
func: SingleRowFunc::Max,
|
||||
});
|
||||
Ok(target_register)
|
||||
}
|
||||
}
|
||||
}
|
||||
None => {
|
||||
@@ -531,7 +586,8 @@ pub fn analyze_expr<'a>(expr: &'a Expr, column_info_out: &mut ColumnInfo<'a>) {
|
||||
args,
|
||||
filter_over: _,
|
||||
} => {
|
||||
let func_type = match normalize_ident(name.0.as_str()).as_str().parse() {
|
||||
let args_count = if let Some(args) = args { args.len() } else { 0 };
|
||||
let func_type = match Func::resolve_function(normalize_ident(name.0.as_str()).as_str(),args_count) {
|
||||
Ok(func) => Some(func),
|
||||
Err(_) => None,
|
||||
};
|
||||
|
||||
72
core/vdbe.rs
72
core/vdbe.rs
@@ -1362,6 +1362,32 @@ impl Program {
|
||||
state.registers[*dest] = result;
|
||||
state.pc += 1;
|
||||
}
|
||||
SingleRowFunc::Min => {
|
||||
let start_reg = *start_reg;
|
||||
let reg_values = state.registers[start_reg..state.registers.len()]
|
||||
.iter()
|
||||
.collect();
|
||||
let min_fn = |a, b| if a < b { a } else { b };
|
||||
if let Some(value) = exec_minmax(reg_values, min_fn) {
|
||||
state.registers[*dest] = value;
|
||||
} else {
|
||||
state.registers[*dest] = OwnedValue::Null;
|
||||
}
|
||||
state.pc += 1;
|
||||
}
|
||||
SingleRowFunc::Max => {
|
||||
let start_reg = *start_reg;
|
||||
let reg_values = state.registers[start_reg..state.registers.len()]
|
||||
.iter()
|
||||
.collect();
|
||||
let max_fn = |a, b| if a > b { a } else { b };
|
||||
if let Some(value) = exec_minmax(reg_values, max_fn) {
|
||||
state.registers[*dest] = value;
|
||||
} else {
|
||||
state.registers[*dest] = OwnedValue::Null;
|
||||
}
|
||||
state.pc += 1;
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -1990,6 +2016,13 @@ fn exec_like(pattern: &str, text: &str) -> bool {
|
||||
re.is_match(text)
|
||||
}
|
||||
|
||||
fn exec_minmax<'a>(
|
||||
regs: Vec<&'a OwnedValue>,
|
||||
op: fn(&'a OwnedValue, &'a OwnedValue) -> &'a OwnedValue,
|
||||
) -> Option<OwnedValue> {
|
||||
regs.into_iter().reduce(|a, b| op(a, b)).cloned()
|
||||
}
|
||||
|
||||
fn exec_round(reg: &OwnedValue, precision: Option<OwnedValue>) -> OwnedValue {
|
||||
let precision = match precision {
|
||||
Some(OwnedValue::Text(x)) => x.parse().unwrap_or(0.0),
|
||||
@@ -2045,7 +2078,7 @@ fn exec_if(reg: &OwnedValue, null_reg: &OwnedValue, not: bool) -> bool {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::{
|
||||
exec_abs, exec_if, exec_length, exec_like, exec_lower, exec_random, exec_round, exec_trim,
|
||||
exec_abs, exec_if, exec_length, exec_like, exec_lower, exec_minmax, exec_random, exec_round, exec_trim,
|
||||
exec_upper, OwnedValue,
|
||||
};
|
||||
use std::rc::Rc;
|
||||
@@ -2070,6 +2103,43 @@ mod tests {
|
||||
assert_eq!(exec_length(&expected_blob), expected_len);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_minmax() {
|
||||
let min_fn = |a, b| if a < b { a } else { b };
|
||||
let max_fn = |a, b| if a > b { a } else { b };
|
||||
let input_int_vec = vec![&OwnedValue::Integer(-1), &OwnedValue::Integer(10)];
|
||||
assert_eq!(
|
||||
exec_minmax(input_int_vec.clone(), min_fn),
|
||||
Some(OwnedValue::Integer(-1))
|
||||
);
|
||||
assert_eq!(
|
||||
exec_minmax(input_int_vec.clone(), max_fn),
|
||||
Some(OwnedValue::Integer(10))
|
||||
);
|
||||
|
||||
let str1 = OwnedValue::Text(Rc::new(String::from("A")));
|
||||
let str2 = OwnedValue::Text(Rc::new(String::from("z")));
|
||||
let input_str_vec = vec![&str2, &str1];
|
||||
assert_eq!(
|
||||
exec_minmax(input_str_vec.clone(), min_fn),
|
||||
Some(OwnedValue::Text(Rc::new(String::from("A"))))
|
||||
);
|
||||
assert_eq!(
|
||||
exec_minmax(input_str_vec.clone(), max_fn),
|
||||
Some(OwnedValue::Text(Rc::new(String::from("z"))))
|
||||
);
|
||||
|
||||
let input_null_vec = vec![&OwnedValue::Null, &OwnedValue::Null];
|
||||
assert_eq!(
|
||||
exec_minmax(input_null_vec.clone(), min_fn),
|
||||
Some(OwnedValue::Null)
|
||||
);
|
||||
assert_eq!(
|
||||
exec_minmax(input_null_vec.clone(), max_fn),
|
||||
Some(OwnedValue::Null)
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_trim() {
|
||||
let input_str = OwnedValue::Text(Rc::new(String::from(" Bob and Alice ")));
|
||||
|
||||
@@ -141,4 +141,27 @@ do_execsql_test length-null {
|
||||
|
||||
do_execsql_test length-empty-text {
|
||||
SELECT length('');
|
||||
} {0}
|
||||
} {0}
|
||||
do_execsql_test min-number {
|
||||
select min(-10,2,3)
|
||||
} {-10}
|
||||
|
||||
do_execsql_test min-str {
|
||||
select min('b','a','z')
|
||||
} {a}
|
||||
|
||||
do_execsql_test min-null {
|
||||
select min(null,null)
|
||||
} {}
|
||||
|
||||
do_execsql_test max-number {
|
||||
select max(-10,2,3)
|
||||
} {3}
|
||||
|
||||
do_execsql_test max-str {
|
||||
select max('b','a','z')
|
||||
} {z}
|
||||
|
||||
do_execsql_test max-null {
|
||||
select max(null,null)
|
||||
} {}
|
||||
Reference in New Issue
Block a user