Merge 'Pass FuncCtx to Insn::Function to keep track of arg count' from Jussi Saurio

https://github.com/penberg/limbo/pull/321 recently fixed an error with
register allocation order concerning binary expressions in
`translate_condition_expr`. I noticed we had the same bug in
`translate_expr` as well, and this led into noticing that our
implementation of `Insn::Function` diverges from SQLite in that it
doesn't provide a context object that keeps track of how many arguments
the function call has.
Not having the arg count available at runtime means that e.g.:
```
do_execsql_test length-date-binary-expr {
  select length(date('now')) = 10;
} {1}
```
Returned 0 instead of 1, because at runtime another register was already
allocated, and `date()` was implemented like this:
```
                    Func::Scalar(ScalarFunc::Date) => {
                        let result = exec_date(&state.registers[*start_reg..]); // all registers starting from start_reg
                        state.registers[*dest] = result;
                        state.pc += 1;
                    }
```
SQlite bytecode engine docs for Function say:
> Invoke a user function (P4 is a pointer to an sqlite3_context object
that contains a pointer to the function to be run) with arguments taken
from register P2 and successors. The number of arguments is in the
sqlite3_context object that P4 points to.
Accordingly, this PR implements a `FuncCtx` struct that contains a
`Func` and `arg_count` (usize).

Reviewed-by: Pere Diaz Bou <pere-altea@hotmail.com>

Closes #323
This commit is contained in:
jussisaurio
2024-09-14 16:12:25 +03:00
5 changed files with 220 additions and 235 deletions

View File

@@ -104,6 +104,22 @@ pub enum Func {
Json(JsonFunc),
}
impl Func {
pub fn to_string(&self) -> String {
match self {
Func::Agg(agg_func) => agg_func.to_string().to_string(),
Func::Scalar(scalar_func) => scalar_func.to_string(),
Func::Json(json_func) => json_func.to_string(),
}
}
}
#[derive(Debug)]
pub struct FuncCtx {
pub func: Func,
pub arg_count: usize,
}
impl Func {
pub fn resolve_function(name: &str, arg_count: usize) -> Result<Func, ()> {
match name {

View File

@@ -2,7 +2,7 @@ use crate::{function::JsonFunc, Result};
use sqlite3_parser::ast::{self, UnaryOperator};
use std::rc::Rc;
use crate::function::{AggFunc, Func, ScalarFunc};
use crate::function::{AggFunc, Func, FuncCtx, ScalarFunc};
use crate::schema::Type;
use crate::util::normalize_ident;
use crate::{
@@ -457,9 +457,12 @@ pub fn translate_condition_expr(
// Only constant patterns for LIKE are supported currently, so this
// is always 1
constant_mask: 1,
func: crate::vdbe::Func::Scalar(ScalarFunc::Like),
start_reg: pattern_reg,
dest: cur_reg,
func: FuncCtx {
func: Func::Scalar(ScalarFunc::Like),
arg_count: 2,
},
});
}
ast::LikeOperator::Glob => todo!(),
@@ -524,8 +527,8 @@ pub fn translate_expr(
ast::Expr::Between { .. } => todo!(),
ast::Expr::Binary(e1, op, e2) => {
let e1_reg = program.alloc_register();
let e2_reg = program.alloc_register();
let _ = translate_expr(program, referenced_tables, e1, e1_reg, cursor_hint)?;
let e2_reg = program.alloc_register();
let _ = translate_expr(program, referenced_tables, e2, e2_reg, cursor_hint)?;
match op {
@@ -634,12 +637,20 @@ pub fn translate_expr(
let func_type: Option<Func> =
Func::resolve_function(normalize_ident(name.0.as_str()).as_str(), args_count).ok();
match func_type {
Some(Func::Agg(_)) => {
if func_type.is_none() {
crate::bail_parse_error!("unknown function {}", name.0);
}
let func_ctx = FuncCtx {
func: func_type.unwrap(),
arg_count: args_count,
};
match &func_ctx.func {
Func::Agg(_) => {
crate::bail_parse_error!("aggregation function in non-aggregation context")
}
Some(Func::Json(j)) => match j {
Func::Json(j) => match j {
JsonFunc::JSON => {
let args = if let Some(args) = args {
if args.len() != 1 {
@@ -661,12 +672,12 @@ pub fn translate_expr(
constant_mask: 0,
start_reg: regs,
dest: target_register,
func: crate::vdbe::Func::Json(j),
func: func_ctx,
});
Ok(target_register)
}
},
Some(Func::Scalar(srf)) => {
Func::Scalar(srf) => {
match srf {
ScalarFunc::Char => {
let args = args.clone().unwrap_or_else(Vec::new);
@@ -680,7 +691,7 @@ pub fn translate_expr(
constant_mask: 0,
start_reg: target_register + 1,
dest: target_register,
func: crate::vdbe::Func::Scalar(srf),
func: func_ctx,
});
Ok(target_register)
}
@@ -742,7 +753,7 @@ pub fn translate_expr(
constant_mask: 0,
start_reg: target_register + 1,
dest: target_register,
func: crate::vdbe::Func::Scalar(srf),
func: func_ctx,
});
Ok(target_register)
}
@@ -822,7 +833,7 @@ pub fn translate_expr(
constant_mask: 1,
start_reg: target_register + 1,
dest: target_register,
func: crate::vdbe::Func::Scalar(srf),
func: func_ctx,
});
Ok(target_register)
}
@@ -859,7 +870,7 @@ pub fn translate_expr(
constant_mask: 0,
start_reg: regs,
dest: target_register,
func: crate::vdbe::Func::Scalar(srf),
func: func_ctx,
});
Ok(target_register)
}
@@ -875,7 +886,7 @@ pub fn translate_expr(
constant_mask: 0,
start_reg: regs,
dest: target_register,
func: crate::vdbe::Func::Scalar(srf),
func: func_ctx,
});
Ok(target_register)
}
@@ -897,7 +908,7 @@ pub fn translate_expr(
constant_mask: 0,
start_reg: target_register + 1,
dest: target_register,
func: crate::vdbe::Func::Scalar(ScalarFunc::Date),
func: func_ctx,
});
Ok(target_register)
}
@@ -949,7 +960,7 @@ pub fn translate_expr(
constant_mask: 0,
start_reg: str_reg,
dest: target_register,
func: crate::vdbe::Func::Scalar(ScalarFunc::Substring),
func: func_ctx,
});
Ok(target_register)
}
@@ -974,7 +985,7 @@ pub fn translate_expr(
constant_mask: 0,
start_reg,
dest: target_register,
func: crate::vdbe::Func::Scalar(srf),
func: func_ctx,
});
Ok(target_register)
}
@@ -996,7 +1007,7 @@ pub fn translate_expr(
constant_mask: 0,
start_reg: target_register + 1,
dest: target_register,
func: crate::vdbe::Func::Scalar(ScalarFunc::Time),
func: func_ctx,
});
Ok(target_register)
}
@@ -1030,7 +1041,7 @@ pub fn translate_expr(
constant_mask: 0,
start_reg: target_register + 1,
dest: target_register,
func: crate::vdbe::Func::Scalar(srf),
func: func_ctx,
});
Ok(target_register)
}
@@ -1064,7 +1075,7 @@ pub fn translate_expr(
constant_mask: 0,
start_reg: target_register + 1,
dest: target_register,
func: crate::vdbe::Func::Scalar(ScalarFunc::Min),
func: func_ctx,
});
Ok(target_register)
}
@@ -1098,7 +1109,7 @@ pub fn translate_expr(
constant_mask: 0,
start_reg: target_register + 1,
dest: target_register,
func: crate::vdbe::Func::Scalar(ScalarFunc::Max),
func: func_ctx,
});
Ok(target_register)
}
@@ -1114,7 +1125,6 @@ pub fn translate_expr(
crate::bail_parse_error!("nullif function with no arguments");
};
let func_reg = program.alloc_register();
let first_reg = program.alloc_register();
translate_expr(
program,
@@ -1133,18 +1143,15 @@ pub fn translate_expr(
)?;
program.emit_insn(Insn::Function {
constant_mask: 0,
start_reg: func_reg,
start_reg: first_reg,
dest: target_register,
func: crate::vdbe::Func::Scalar(srf),
func: func_ctx,
});
Ok(target_register)
}
}
}
None => {
crate::bail_parse_error!("unknown function {}", name.0);
}
}
}
ast::Expr::FunctionCallStar { .. } => todo!(),
@@ -1528,7 +1535,7 @@ pub fn translate_aggregation(
expr,
expr_reg,
cursor_hint,
);
)?;
program.emit_insn(Insn::AggStep {
acc_reg: target_register,
col: expr_reg,
@@ -1549,7 +1556,7 @@ pub fn translate_aggregation(
expr,
expr_reg,
cursor_hint,
);
)?;
program.emit_insn(Insn::AggStep {
acc_reg: target_register,
col: expr_reg,

View File

@@ -542,9 +542,14 @@ pub fn insn_to_str(
*constant_mask,
*start_reg as i32,
*dest as i32,
OwnedValue::Text(Rc::new(func.to_string())),
OwnedValue::Text(Rc::new(func.func.to_string())),
0,
format!("r[{}]=func(r[{}..])", dest, start_reg),
format!(
"r[{}]=func(r[{}..{}])",
dest,
start_reg,
start_reg + func.arg_count - 1
),
),
Insn::InitCoroutine {
yield_reg,

View File

@@ -24,7 +24,7 @@ pub mod sorter;
mod datetime;
use crate::error::LimboError;
use crate::function::{AggFunc, JsonFunc, ScalarFunc};
use crate::function::{AggFunc, FuncCtx, JsonFunc, ScalarFunc};
use crate::json::get_json;
use crate::pseudo::PseudoCursor;
use crate::schema::Table;
@@ -305,7 +305,7 @@ pub enum Insn {
constant_mask: i32, // P1
start_reg: usize, // P2, start of argument registers
dest: usize, // P3
func: Func, // P4
func: FuncCtx, // P4
},
InitCoroutine {
@@ -1179,216 +1179,169 @@ impl Program {
func,
start_reg,
dest,
} => match func {
Func::Json(JsonFunc::JSON) => {
let json_value = &state.registers[*start_reg];
let json_str = get_json(json_value);
match json_str {
Ok(json) => state.registers[*dest] = json,
Err(e) => return Err(e),
} => {
let arg_count = func.arg_count;
match &func.func {
crate::function::Func::Json(JsonFunc::JSON) => {
let json_value = &state.registers[*start_reg];
let json_str = get_json(json_value);
match json_str {
Ok(json) => state.registers[*dest] = json,
Err(e) => return Err(e),
}
}
state.pc += 1;
}
Func::Scalar(ScalarFunc::Char) => {
let start_reg = *start_reg;
let reg_values = state.registers[start_reg..state.registers.len()].to_vec();
state.registers[*dest] = exec_char(reg_values);
state.pc += 1;
}
Func::Scalar(ScalarFunc::Coalesce) => {}
Func::Scalar(ScalarFunc::Concat) => {
let start_reg = *start_reg;
let result = exec_concat(&state.registers[start_reg..]);
state.registers[*dest] = result;
state.pc += 1;
}
Func::Scalar(ScalarFunc::IfNull) => {}
Func::Scalar(ScalarFunc::Like) => {
let start_reg = *start_reg;
assert!(
start_reg + 2 <= state.registers.len(),
"not enough registers {} < {}",
start_reg,
state.registers.len()
);
let pattern = &state.registers[start_reg];
let text = &state.registers[start_reg + 1];
let result = match (pattern, text) {
(OwnedValue::Text(pattern), OwnedValue::Text(text)) => {
let cache = if *constant_mask > 0 {
Some(&mut state.regex_cache)
} else {
None
crate::function::Func::Scalar(scalar_func) => match scalar_func {
ScalarFunc::Char => {
let reg_values =
state.registers[*start_reg..*start_reg + arg_count].to_vec();
state.registers[*dest] = exec_char(reg_values);
}
ScalarFunc::Coalesce => {}
ScalarFunc::Concat => {
let result = exec_concat(
&state.registers[*start_reg..*start_reg + arg_count],
);
state.registers[*dest] = result;
}
ScalarFunc::IfNull => {}
ScalarFunc::Like => {
let pattern = &state.registers[*start_reg];
let text = &state.registers[*start_reg + 1];
let result = match (pattern, text) {
(OwnedValue::Text(pattern), OwnedValue::Text(text)) => {
let cache = if *constant_mask > 0 {
Some(&mut state.regex_cache)
} else {
None
};
OwnedValue::Integer(exec_like(cache, pattern, text) as i64)
}
_ => {
unreachable!("Like on non-text registers");
}
};
OwnedValue::Integer(exec_like(cache, pattern, text) as i64)
state.registers[*dest] = result;
}
_ => {
unreachable!("Like on non-text registers");
ScalarFunc::Abs
| ScalarFunc::Lower
| ScalarFunc::Upper
| ScalarFunc::Length
| ScalarFunc::Unicode
| ScalarFunc::Quote => {
let reg_value = state.registers[*start_reg].borrow_mut();
let result = match scalar_func {
ScalarFunc::Abs => exec_abs(reg_value),
ScalarFunc::Lower => exec_lower(reg_value),
ScalarFunc::Upper => exec_upper(reg_value),
ScalarFunc::Length => Some(exec_length(reg_value)),
ScalarFunc::Unicode => Some(exec_unicode(reg_value)),
ScalarFunc::Quote => Some(exec_quote(reg_value)),
_ => unreachable!(),
};
state.registers[*dest] = result.unwrap_or(OwnedValue::Null);
}
};
state.registers[*dest] = result;
state.pc += 1;
}
Func::Scalar(ScalarFunc::Abs) => {
let reg_value = state.registers[*start_reg].borrow_mut();
if let Some(value) = exec_abs(reg_value) {
state.registers[*dest] = value;
} else {
state.registers[*dest] = OwnedValue::Null;
}
state.pc += 1;
}
Func::Scalar(ScalarFunc::Upper) => {
let reg_value = state.registers[*start_reg].borrow_mut();
if let Some(value) = exec_upper(reg_value) {
state.registers[*dest] = value;
} else {
state.registers[*dest] = OwnedValue::Null;
}
state.pc += 1;
}
Func::Scalar(ScalarFunc::Lower) => {
let reg_value = state.registers[*start_reg].borrow_mut();
if let Some(value) = exec_lower(reg_value) {
state.registers[*dest] = value;
} else {
state.registers[*dest] = OwnedValue::Null;
}
state.pc += 1;
}
Func::Scalar(ScalarFunc::Length) => {
let reg_value = state.registers[*start_reg].borrow_mut();
state.registers[*dest] = exec_length(reg_value);
state.pc += 1;
}
Func::Scalar(ScalarFunc::Random) => {
state.registers[*dest] = exec_random();
state.pc += 1;
}
Func::Scalar(ScalarFunc::Trim) => {
let start_reg = *start_reg;
let reg_value = state.registers[start_reg].clone();
let pattern_value = state.registers.get(start_reg + 1).cloned();
let result = exec_trim(&reg_value, pattern_value);
state.registers[*dest] = result;
state.pc += 1;
}
Func::Scalar(ScalarFunc::LTrim) => {
let start_reg = *start_reg;
let reg_value = state.registers[start_reg].clone();
let pattern_value = state.registers.get(start_reg + 1).cloned();
let result = exec_ltrim(&reg_value, pattern_value);
state.registers[*dest] = result;
state.pc += 1;
}
Func::Scalar(ScalarFunc::RTrim) => {
let start_reg = *start_reg;
let reg_value = state.registers[start_reg].clone();
let pattern_value = state.registers.get(start_reg + 1).cloned();
let result = exec_rtrim(&reg_value, pattern_value);
state.registers[*dest] = result;
state.pc += 1;
}
Func::Scalar(ScalarFunc::Round) => {
let start_reg = *start_reg;
let reg_value = state.registers[start_reg].clone();
let precision_value = state.registers.get(start_reg + 1).cloned();
let result = exec_round(&reg_value, precision_value);
state.registers[*dest] = result;
state.pc += 1;
}
Func::Scalar(ScalarFunc::Min) => {
let start_reg = *start_reg;
let reg_values = state.registers[start_reg..state.registers.len()]
.iter()
.collect();
let min_fn = |a, b| if a < b { a } else { b };
if let Some(value) = exec_minmax(reg_values, min_fn) {
state.registers[*dest] = value;
} else {
state.registers[*dest] = OwnedValue::Null;
}
state.pc += 1;
}
Func::Scalar(ScalarFunc::Max) => {
let start_reg = *start_reg;
let reg_values = state.registers[start_reg..state.registers.len()]
.iter()
.collect();
let max_fn = |a, b| if a > b { a } else { b };
if let Some(value) = exec_minmax(reg_values, max_fn) {
state.registers[*dest] = value;
} else {
state.registers[*dest] = OwnedValue::Null;
}
state.pc += 1;
}
Func::Scalar(ScalarFunc::Nullif) => {
let start_reg = *start_reg;
let first_value = &state.registers[start_reg + 1];
let second_value = &state.registers[start_reg + 2];
state.registers[*dest] = exec_nullif(first_value, second_value);
state.pc += 1;
}
Func::Scalar(ScalarFunc::Substr) | Func::Scalar(ScalarFunc::Substring) => {
let start_reg = *start_reg;
let str_value = &state.registers[start_reg];
let start_value = &state.registers[start_reg + 1];
let length_value = &state.registers[start_reg + 2];
let result = exec_substring(str_value, start_value, length_value);
state.registers[*dest] = result;
state.pc += 1;
}
Func::Scalar(ScalarFunc::Date) => {
let result = exec_date(&state.registers[*start_reg..]);
state.registers[*dest] = result;
state.pc += 1;
}
Func::Scalar(ScalarFunc::Time) => {
let result = exec_time(&state.registers[*start_reg..]);
state.registers[*dest] = result;
state.pc += 1;
}
Func::Scalar(ScalarFunc::UnixEpoch) => {
if *start_reg == 0 {
let unixepoch: String =
exec_unixepoch(&OwnedValue::Text(Rc::new("now".to_string())))?;
state.registers[*dest] = OwnedValue::Text(Rc::new(unixepoch));
} else {
let datetime_value = &state.registers[*start_reg];
let unixepoch = exec_unixepoch(datetime_value);
match unixepoch {
Ok(time) => {
state.registers[*dest] = OwnedValue::Text(Rc::new(time))
}
Err(e) => {
return Err(LimboError::ParseError(format!(
"Error encountered while parsing datetime value: {}",
e
)));
ScalarFunc::Random => {
state.registers[*dest] = exec_random();
}
ScalarFunc::Trim => {
let reg_value = state.registers[*start_reg].clone();
let pattern_value = state.registers.get(*start_reg + 1).cloned();
let result = exec_trim(&reg_value, pattern_value);
state.registers[*dest] = result;
}
ScalarFunc::LTrim => {
let reg_value = state.registers[*start_reg].clone();
let pattern_value = state.registers.get(*start_reg + 1).cloned();
let result = exec_ltrim(&reg_value, pattern_value);
state.registers[*dest] = result;
}
ScalarFunc::RTrim => {
let reg_value = state.registers[*start_reg].clone();
let pattern_value = state.registers.get(*start_reg + 1).cloned();
let result = exec_rtrim(&reg_value, pattern_value);
state.registers[*dest] = result;
}
ScalarFunc::Round => {
let reg_value = state.registers[*start_reg].clone();
let precision_value = state.registers.get(*start_reg + 1).cloned();
let result = exec_round(&reg_value, precision_value);
state.registers[*dest] = result;
}
ScalarFunc::Min => {
let reg_values = state.registers
[*start_reg..*start_reg + arg_count]
.iter()
.collect();
let min_fn = |a, b| if a < b { a } else { b };
if let Some(value) = exec_minmax(reg_values, min_fn) {
state.registers[*dest] = value;
} else {
state.registers[*dest] = OwnedValue::Null;
}
}
ScalarFunc::Max => {
let reg_values = state.registers
[*start_reg..*start_reg + arg_count]
.iter()
.collect();
let max_fn = |a, b| if a > b { a } else { b };
if let Some(value) = exec_minmax(reg_values, max_fn) {
state.registers[*dest] = value;
} else {
state.registers[*dest] = OwnedValue::Null;
}
}
ScalarFunc::Nullif => {
let first_value = &state.registers[*start_reg];
let second_value = &state.registers[*start_reg + 1];
state.registers[*dest] = exec_nullif(first_value, second_value);
}
ScalarFunc::Substr | ScalarFunc::Substring => {
let str_value = &state.registers[*start_reg];
let start_value = &state.registers[*start_reg + 1];
let length_value = &state.registers[*start_reg + 2];
let result = exec_substring(str_value, start_value, length_value);
state.registers[*dest] = result;
}
ScalarFunc::Date => {
let result =
exec_date(&state.registers[*start_reg..*start_reg + arg_count]);
state.registers[*dest] = result;
}
ScalarFunc::Time => {
let result =
exec_time(&state.registers[*start_reg..*start_reg + arg_count]);
state.registers[*dest] = result;
}
ScalarFunc::UnixEpoch => {
if *start_reg == 0 {
let unixepoch: String = exec_unixepoch(&OwnedValue::Text(
Rc::new("now".to_string()),
))?;
state.registers[*dest] = OwnedValue::Text(Rc::new(unixepoch));
} else {
let datetime_value = &state.registers[*start_reg];
let unixepoch = exec_unixepoch(datetime_value);
match unixepoch {
Ok(time) => {
state.registers[*dest] = OwnedValue::Text(Rc::new(time))
}
Err(e) => {
return Err(LimboError::ParseError(format!(
"Error encountered while parsing datetime value: {}",
e
)));
}
}
}
}
},
crate::function::Func::Agg(_) => {
unreachable!("Aggregate functions should not be handled here")
}
state.pc += 1
}
Func::Scalar(ScalarFunc::Unicode) => {
let reg_value = state.registers[*start_reg].borrow_mut();
state.registers[*dest] = exec_unicode(reg_value);
state.pc += 1;
}
Func::Scalar(ScalarFunc::Quote) => {
let reg_value = state.registers[*start_reg].borrow_mut();
state.registers[*dest] = exec_quote(reg_value);
state.pc += 1;
}
},
state.pc += 1;
}
Insn::InitCoroutine {
yield_reg,
jump_on_definition,

View File

@@ -255,6 +255,10 @@ do_execsql_test length-empty-text {
SELECT length('');
} {0}
do_execsql_test length-date-binary-expr {
select length(date('now')) = 10;
} {1}
do_execsql_test min-number {
select min(-10,2,3)
} {-10}