From 0839211a495fcbdc57741b6534613af2f93d3f01 Mon Sep 17 00:00:00 2001 From: jussisaurio Date: Sat, 14 Sep 2024 10:31:42 +0300 Subject: [PATCH] Pass FuncCtx to Insn::Function to keep track of arg count --- core/function.rs | 16 ++ core/translate/expr.rs | 65 +++--- core/vdbe/explain.rs | 9 +- core/vdbe/mod.rs | 361 +++++++++++++++------------------- testing/scalar-functions.test | 4 + 5 files changed, 220 insertions(+), 235 deletions(-) diff --git a/core/function.rs b/core/function.rs index ae1ab0966..a368867a6 100644 --- a/core/function.rs +++ b/core/function.rs @@ -104,6 +104,22 @@ pub enum Func { Json(JsonFunc), } +impl Func { + pub fn to_string(&self) -> String { + match self { + Func::Agg(agg_func) => agg_func.to_string().to_string(), + Func::Scalar(scalar_func) => scalar_func.to_string(), + Func::Json(json_func) => json_func.to_string(), + } + } +} + +#[derive(Debug)] +pub struct FuncCtx { + pub func: Func, + pub arg_count: usize, +} + impl Func { pub fn resolve_function(name: &str, arg_count: usize) -> Result { match name { diff --git a/core/translate/expr.rs b/core/translate/expr.rs index 3db50331d..6e18c8605 100644 --- a/core/translate/expr.rs +++ b/core/translate/expr.rs @@ -2,7 +2,7 @@ use crate::{function::JsonFunc, Result}; use sqlite3_parser::ast::{self, UnaryOperator}; use std::rc::Rc; -use crate::function::{AggFunc, Func, ScalarFunc}; +use crate::function::{AggFunc, Func, FuncCtx, ScalarFunc}; use crate::schema::Type; use crate::util::normalize_ident; use crate::{ @@ -457,9 +457,12 @@ pub fn translate_condition_expr( // Only constant patterns for LIKE are supported currently, so this // is always 1 constant_mask: 1, - func: crate::vdbe::Func::Scalar(ScalarFunc::Like), start_reg: pattern_reg, dest: cur_reg, + func: FuncCtx { + func: Func::Scalar(ScalarFunc::Like), + arg_count: 2, + }, }); } ast::LikeOperator::Glob => todo!(), @@ -524,8 +527,8 @@ pub fn translate_expr( ast::Expr::Between { .. } => todo!(), ast::Expr::Binary(e1, op, e2) => { let e1_reg = program.alloc_register(); - let e2_reg = program.alloc_register(); let _ = translate_expr(program, referenced_tables, e1, e1_reg, cursor_hint)?; + let e2_reg = program.alloc_register(); let _ = translate_expr(program, referenced_tables, e2, e2_reg, cursor_hint)?; match op { @@ -634,12 +637,20 @@ pub fn translate_expr( let func_type: Option = Func::resolve_function(normalize_ident(name.0.as_str()).as_str(), args_count).ok(); - match func_type { - Some(Func::Agg(_)) => { + if func_type.is_none() { + crate::bail_parse_error!("unknown function {}", name.0); + } + + let func_ctx = FuncCtx { + func: func_type.unwrap(), + arg_count: args_count, + }; + + match &func_ctx.func { + Func::Agg(_) => { crate::bail_parse_error!("aggregation function in non-aggregation context") } - - Some(Func::Json(j)) => match j { + Func::Json(j) => match j { JsonFunc::JSON => { let args = if let Some(args) = args { if args.len() != 1 { @@ -661,12 +672,12 @@ pub fn translate_expr( constant_mask: 0, start_reg: regs, dest: target_register, - func: crate::vdbe::Func::Json(j), + func: func_ctx, }); Ok(target_register) } }, - Some(Func::Scalar(srf)) => { + Func::Scalar(srf) => { match srf { ScalarFunc::Char => { let args = args.clone().unwrap_or_else(Vec::new); @@ -680,7 +691,7 @@ pub fn translate_expr( constant_mask: 0, start_reg: target_register + 1, dest: target_register, - func: crate::vdbe::Func::Scalar(srf), + func: func_ctx, }); Ok(target_register) } @@ -742,7 +753,7 @@ pub fn translate_expr( constant_mask: 0, start_reg: target_register + 1, dest: target_register, - func: crate::vdbe::Func::Scalar(srf), + func: func_ctx, }); Ok(target_register) } @@ -822,7 +833,7 @@ pub fn translate_expr( constant_mask: 1, start_reg: target_register + 1, dest: target_register, - func: crate::vdbe::Func::Scalar(srf), + func: func_ctx, }); Ok(target_register) } @@ -859,7 +870,7 @@ pub fn translate_expr( constant_mask: 0, start_reg: regs, dest: target_register, - func: crate::vdbe::Func::Scalar(srf), + func: func_ctx, }); Ok(target_register) } @@ -875,7 +886,7 @@ pub fn translate_expr( constant_mask: 0, start_reg: regs, dest: target_register, - func: crate::vdbe::Func::Scalar(srf), + func: func_ctx, }); Ok(target_register) } @@ -897,7 +908,7 @@ pub fn translate_expr( constant_mask: 0, start_reg: target_register + 1, dest: target_register, - func: crate::vdbe::Func::Scalar(ScalarFunc::Date), + func: func_ctx, }); Ok(target_register) } @@ -949,7 +960,7 @@ pub fn translate_expr( constant_mask: 0, start_reg: str_reg, dest: target_register, - func: crate::vdbe::Func::Scalar(ScalarFunc::Substring), + func: func_ctx, }); Ok(target_register) } @@ -974,7 +985,7 @@ pub fn translate_expr( constant_mask: 0, start_reg, dest: target_register, - func: crate::vdbe::Func::Scalar(srf), + func: func_ctx, }); Ok(target_register) } @@ -996,7 +1007,7 @@ pub fn translate_expr( constant_mask: 0, start_reg: target_register + 1, dest: target_register, - func: crate::vdbe::Func::Scalar(ScalarFunc::Time), + func: func_ctx, }); Ok(target_register) } @@ -1030,7 +1041,7 @@ pub fn translate_expr( constant_mask: 0, start_reg: target_register + 1, dest: target_register, - func: crate::vdbe::Func::Scalar(srf), + func: func_ctx, }); Ok(target_register) } @@ -1064,7 +1075,7 @@ pub fn translate_expr( constant_mask: 0, start_reg: target_register + 1, dest: target_register, - func: crate::vdbe::Func::Scalar(ScalarFunc::Min), + func: func_ctx, }); Ok(target_register) } @@ -1098,7 +1109,7 @@ pub fn translate_expr( constant_mask: 0, start_reg: target_register + 1, dest: target_register, - func: crate::vdbe::Func::Scalar(ScalarFunc::Max), + func: func_ctx, }); Ok(target_register) } @@ -1114,7 +1125,6 @@ pub fn translate_expr( crate::bail_parse_error!("nullif function with no arguments"); }; - let func_reg = program.alloc_register(); let first_reg = program.alloc_register(); translate_expr( program, @@ -1133,18 +1143,15 @@ pub fn translate_expr( )?; program.emit_insn(Insn::Function { constant_mask: 0, - start_reg: func_reg, + start_reg: first_reg, dest: target_register, - func: crate::vdbe::Func::Scalar(srf), + func: func_ctx, }); Ok(target_register) } } } - None => { - crate::bail_parse_error!("unknown function {}", name.0); - } } } ast::Expr::FunctionCallStar { .. } => todo!(), @@ -1528,7 +1535,7 @@ pub fn translate_aggregation( expr, expr_reg, cursor_hint, - ); + )?; program.emit_insn(Insn::AggStep { acc_reg: target_register, col: expr_reg, @@ -1549,7 +1556,7 @@ pub fn translate_aggregation( expr, expr_reg, cursor_hint, - ); + )?; program.emit_insn(Insn::AggStep { acc_reg: target_register, col: expr_reg, diff --git a/core/vdbe/explain.rs b/core/vdbe/explain.rs index 38f2d6bcc..cddc7e9e3 100644 --- a/core/vdbe/explain.rs +++ b/core/vdbe/explain.rs @@ -542,9 +542,14 @@ pub fn insn_to_str( *constant_mask, *start_reg as i32, *dest as i32, - OwnedValue::Text(Rc::new(func.to_string())), + OwnedValue::Text(Rc::new(func.func.to_string())), 0, - format!("r[{}]=func(r[{}..])", dest, start_reg), + format!( + "r[{}]=func(r[{}..{}])", + dest, + start_reg, + start_reg + func.arg_count - 1 + ), ), Insn::InitCoroutine { yield_reg, diff --git a/core/vdbe/mod.rs b/core/vdbe/mod.rs index 8a0296e68..eeecca3ec 100644 --- a/core/vdbe/mod.rs +++ b/core/vdbe/mod.rs @@ -24,7 +24,7 @@ pub mod sorter; mod datetime; use crate::error::LimboError; -use crate::function::{AggFunc, JsonFunc, ScalarFunc}; +use crate::function::{AggFunc, FuncCtx, JsonFunc, ScalarFunc}; use crate::json::get_json; use crate::pseudo::PseudoCursor; use crate::schema::Table; @@ -305,7 +305,7 @@ pub enum Insn { constant_mask: i32, // P1 start_reg: usize, // P2, start of argument registers dest: usize, // P3 - func: Func, // P4 + func: FuncCtx, // P4 }, InitCoroutine { @@ -1179,216 +1179,169 @@ impl Program { func, start_reg, dest, - } => match func { - Func::Json(JsonFunc::JSON) => { - let json_value = &state.registers[*start_reg]; - let json_str = get_json(json_value); - match json_str { - Ok(json) => state.registers[*dest] = json, - Err(e) => return Err(e), + } => { + let arg_count = func.arg_count; + match &func.func { + crate::function::Func::Json(JsonFunc::JSON) => { + let json_value = &state.registers[*start_reg]; + let json_str = get_json(json_value); + match json_str { + Ok(json) => state.registers[*dest] = json, + Err(e) => return Err(e), + } } - state.pc += 1; - } - Func::Scalar(ScalarFunc::Char) => { - let start_reg = *start_reg; - let reg_values = state.registers[start_reg..state.registers.len()].to_vec(); - state.registers[*dest] = exec_char(reg_values); - state.pc += 1; - } - Func::Scalar(ScalarFunc::Coalesce) => {} - Func::Scalar(ScalarFunc::Concat) => { - let start_reg = *start_reg; - let result = exec_concat(&state.registers[start_reg..]); - state.registers[*dest] = result; - state.pc += 1; - } - Func::Scalar(ScalarFunc::IfNull) => {} - Func::Scalar(ScalarFunc::Like) => { - let start_reg = *start_reg; - assert!( - start_reg + 2 <= state.registers.len(), - "not enough registers {} < {}", - start_reg, - state.registers.len() - ); - let pattern = &state.registers[start_reg]; - let text = &state.registers[start_reg + 1]; - let result = match (pattern, text) { - (OwnedValue::Text(pattern), OwnedValue::Text(text)) => { - let cache = if *constant_mask > 0 { - Some(&mut state.regex_cache) - } else { - None + crate::function::Func::Scalar(scalar_func) => match scalar_func { + ScalarFunc::Char => { + let reg_values = + state.registers[*start_reg..*start_reg + arg_count].to_vec(); + state.registers[*dest] = exec_char(reg_values); + } + ScalarFunc::Coalesce => {} + ScalarFunc::Concat => { + let result = exec_concat( + &state.registers[*start_reg..*start_reg + arg_count], + ); + state.registers[*dest] = result; + } + ScalarFunc::IfNull => {} + ScalarFunc::Like => { + let pattern = &state.registers[*start_reg]; + let text = &state.registers[*start_reg + 1]; + let result = match (pattern, text) { + (OwnedValue::Text(pattern), OwnedValue::Text(text)) => { + let cache = if *constant_mask > 0 { + Some(&mut state.regex_cache) + } else { + None + }; + OwnedValue::Integer(exec_like(cache, pattern, text) as i64) + } + _ => { + unreachable!("Like on non-text registers"); + } }; - OwnedValue::Integer(exec_like(cache, pattern, text) as i64) + state.registers[*dest] = result; } - _ => { - unreachable!("Like on non-text registers"); + ScalarFunc::Abs + | ScalarFunc::Lower + | ScalarFunc::Upper + | ScalarFunc::Length + | ScalarFunc::Unicode + | ScalarFunc::Quote => { + let reg_value = state.registers[*start_reg].borrow_mut(); + let result = match scalar_func { + ScalarFunc::Abs => exec_abs(reg_value), + ScalarFunc::Lower => exec_lower(reg_value), + ScalarFunc::Upper => exec_upper(reg_value), + ScalarFunc::Length => Some(exec_length(reg_value)), + ScalarFunc::Unicode => Some(exec_unicode(reg_value)), + ScalarFunc::Quote => Some(exec_quote(reg_value)), + _ => unreachable!(), + }; + state.registers[*dest] = result.unwrap_or(OwnedValue::Null); } - }; - state.registers[*dest] = result; - state.pc += 1; - } - Func::Scalar(ScalarFunc::Abs) => { - let reg_value = state.registers[*start_reg].borrow_mut(); - if let Some(value) = exec_abs(reg_value) { - state.registers[*dest] = value; - } else { - state.registers[*dest] = OwnedValue::Null; - } - state.pc += 1; - } - Func::Scalar(ScalarFunc::Upper) => { - let reg_value = state.registers[*start_reg].borrow_mut(); - if let Some(value) = exec_upper(reg_value) { - state.registers[*dest] = value; - } else { - state.registers[*dest] = OwnedValue::Null; - } - state.pc += 1; - } - Func::Scalar(ScalarFunc::Lower) => { - let reg_value = state.registers[*start_reg].borrow_mut(); - if let Some(value) = exec_lower(reg_value) { - state.registers[*dest] = value; - } else { - state.registers[*dest] = OwnedValue::Null; - } - state.pc += 1; - } - Func::Scalar(ScalarFunc::Length) => { - let reg_value = state.registers[*start_reg].borrow_mut(); - state.registers[*dest] = exec_length(reg_value); - state.pc += 1; - } - Func::Scalar(ScalarFunc::Random) => { - state.registers[*dest] = exec_random(); - state.pc += 1; - } - Func::Scalar(ScalarFunc::Trim) => { - let start_reg = *start_reg; - let reg_value = state.registers[start_reg].clone(); - let pattern_value = state.registers.get(start_reg + 1).cloned(); - - let result = exec_trim(®_value, pattern_value); - - state.registers[*dest] = result; - state.pc += 1; - } - Func::Scalar(ScalarFunc::LTrim) => { - let start_reg = *start_reg; - let reg_value = state.registers[start_reg].clone(); - let pattern_value = state.registers.get(start_reg + 1).cloned(); - - let result = exec_ltrim(®_value, pattern_value); - - state.registers[*dest] = result; - state.pc += 1; - } - Func::Scalar(ScalarFunc::RTrim) => { - let start_reg = *start_reg; - let reg_value = state.registers[start_reg].clone(); - let pattern_value = state.registers.get(start_reg + 1).cloned(); - - let result = exec_rtrim(®_value, pattern_value); - - state.registers[*dest] = result; - state.pc += 1; - } - Func::Scalar(ScalarFunc::Round) => { - let start_reg = *start_reg; - let reg_value = state.registers[start_reg].clone(); - let precision_value = state.registers.get(start_reg + 1).cloned(); - let result = exec_round(®_value, precision_value); - state.registers[*dest] = result; - state.pc += 1; - } - Func::Scalar(ScalarFunc::Min) => { - let start_reg = *start_reg; - let reg_values = state.registers[start_reg..state.registers.len()] - .iter() - .collect(); - let min_fn = |a, b| if a < b { a } else { b }; - if let Some(value) = exec_minmax(reg_values, min_fn) { - state.registers[*dest] = value; - } else { - state.registers[*dest] = OwnedValue::Null; - } - state.pc += 1; - } - Func::Scalar(ScalarFunc::Max) => { - let start_reg = *start_reg; - let reg_values = state.registers[start_reg..state.registers.len()] - .iter() - .collect(); - let max_fn = |a, b| if a > b { a } else { b }; - if let Some(value) = exec_minmax(reg_values, max_fn) { - state.registers[*dest] = value; - } else { - state.registers[*dest] = OwnedValue::Null; - } - state.pc += 1; - } - Func::Scalar(ScalarFunc::Nullif) => { - let start_reg = *start_reg; - let first_value = &state.registers[start_reg + 1]; - let second_value = &state.registers[start_reg + 2]; - state.registers[*dest] = exec_nullif(first_value, second_value); - state.pc += 1; - } - Func::Scalar(ScalarFunc::Substr) | Func::Scalar(ScalarFunc::Substring) => { - let start_reg = *start_reg; - let str_value = &state.registers[start_reg]; - let start_value = &state.registers[start_reg + 1]; - let length_value = &state.registers[start_reg + 2]; - - let result = exec_substring(str_value, start_value, length_value); - state.registers[*dest] = result; - state.pc += 1; - } - Func::Scalar(ScalarFunc::Date) => { - let result = exec_date(&state.registers[*start_reg..]); - state.registers[*dest] = result; - state.pc += 1; - } - Func::Scalar(ScalarFunc::Time) => { - let result = exec_time(&state.registers[*start_reg..]); - state.registers[*dest] = result; - state.pc += 1; - } - Func::Scalar(ScalarFunc::UnixEpoch) => { - if *start_reg == 0 { - let unixepoch: String = - exec_unixepoch(&OwnedValue::Text(Rc::new("now".to_string())))?; - state.registers[*dest] = OwnedValue::Text(Rc::new(unixepoch)); - } else { - let datetime_value = &state.registers[*start_reg]; - let unixepoch = exec_unixepoch(datetime_value); - match unixepoch { - Ok(time) => { - state.registers[*dest] = OwnedValue::Text(Rc::new(time)) - } - Err(e) => { - return Err(LimboError::ParseError(format!( - "Error encountered while parsing datetime value: {}", - e - ))); + ScalarFunc::Random => { + state.registers[*dest] = exec_random(); + } + ScalarFunc::Trim => { + let reg_value = state.registers[*start_reg].clone(); + let pattern_value = state.registers.get(*start_reg + 1).cloned(); + let result = exec_trim(®_value, pattern_value); + state.registers[*dest] = result; + } + ScalarFunc::LTrim => { + let reg_value = state.registers[*start_reg].clone(); + let pattern_value = state.registers.get(*start_reg + 1).cloned(); + let result = exec_ltrim(®_value, pattern_value); + state.registers[*dest] = result; + } + ScalarFunc::RTrim => { + let reg_value = state.registers[*start_reg].clone(); + let pattern_value = state.registers.get(*start_reg + 1).cloned(); + let result = exec_rtrim(®_value, pattern_value); + state.registers[*dest] = result; + } + ScalarFunc::Round => { + let reg_value = state.registers[*start_reg].clone(); + let precision_value = state.registers.get(*start_reg + 1).cloned(); + let result = exec_round(®_value, precision_value); + state.registers[*dest] = result; + } + ScalarFunc::Min => { + let reg_values = state.registers + [*start_reg..*start_reg + arg_count] + .iter() + .collect(); + let min_fn = |a, b| if a < b { a } else { b }; + if let Some(value) = exec_minmax(reg_values, min_fn) { + state.registers[*dest] = value; + } else { + state.registers[*dest] = OwnedValue::Null; } } + ScalarFunc::Max => { + let reg_values = state.registers + [*start_reg..*start_reg + arg_count] + .iter() + .collect(); + let max_fn = |a, b| if a > b { a } else { b }; + if let Some(value) = exec_minmax(reg_values, max_fn) { + state.registers[*dest] = value; + } else { + state.registers[*dest] = OwnedValue::Null; + } + } + ScalarFunc::Nullif => { + let first_value = &state.registers[*start_reg]; + let second_value = &state.registers[*start_reg + 1]; + state.registers[*dest] = exec_nullif(first_value, second_value); + } + ScalarFunc::Substr | ScalarFunc::Substring => { + let str_value = &state.registers[*start_reg]; + let start_value = &state.registers[*start_reg + 1]; + let length_value = &state.registers[*start_reg + 2]; + let result = exec_substring(str_value, start_value, length_value); + state.registers[*dest] = result; + } + ScalarFunc::Date => { + let result = + exec_date(&state.registers[*start_reg..*start_reg + arg_count]); + state.registers[*dest] = result; + } + ScalarFunc::Time => { + let result = + exec_time(&state.registers[*start_reg..*start_reg + arg_count]); + state.registers[*dest] = result; + } + ScalarFunc::UnixEpoch => { + if *start_reg == 0 { + let unixepoch: String = exec_unixepoch(&OwnedValue::Text( + Rc::new("now".to_string()), + ))?; + state.registers[*dest] = OwnedValue::Text(Rc::new(unixepoch)); + } else { + let datetime_value = &state.registers[*start_reg]; + let unixepoch = exec_unixepoch(datetime_value); + match unixepoch { + Ok(time) => { + state.registers[*dest] = OwnedValue::Text(Rc::new(time)) + } + Err(e) => { + return Err(LimboError::ParseError(format!( + "Error encountered while parsing datetime value: {}", + e + ))); + } + } + } + } + }, + crate::function::Func::Agg(_) => { + unreachable!("Aggregate functions should not be handled here") } - state.pc += 1 } - Func::Scalar(ScalarFunc::Unicode) => { - let reg_value = state.registers[*start_reg].borrow_mut(); - state.registers[*dest] = exec_unicode(reg_value); - state.pc += 1; - } - Func::Scalar(ScalarFunc::Quote) => { - let reg_value = state.registers[*start_reg].borrow_mut(); - state.registers[*dest] = exec_quote(reg_value); - state.pc += 1; - } - }, + state.pc += 1; + } Insn::InitCoroutine { yield_reg, jump_on_definition, diff --git a/testing/scalar-functions.test b/testing/scalar-functions.test index da2799799..d1a6699e1 100755 --- a/testing/scalar-functions.test +++ b/testing/scalar-functions.test @@ -255,6 +255,10 @@ do_execsql_test length-empty-text { SELECT length(''); } {0} +do_execsql_test length-date-binary-expr { + select length(date('now')) = 10; +} {1} + do_execsql_test min-number { select min(-10,2,3) } {-10}