Merge 'Add implementation and tests for replace scalar function' from Alperen Keleş

Adds `replace` scalar function.

Reviewed-by: Jussi Saurio <jussi.saurio@gmail.com>

Closes #446
This commit is contained in:
Pekka Enberg
2024-12-13 11:02:08 +02:00
5 changed files with 212 additions and 2 deletions

View File

@@ -133,7 +133,7 @@ Feature support of [sqlite expr syntax](https://www.sqlite.org/lang_expr.html).
| quote(X) | Yes | |
| random() | Yes | |
| randomblob(N) | Yes | |
| replace(X,Y,Z) | No | |
| replace(X,Y,Z) | Yes | |
| round(X) | Yes | |
| round(X,Y) | Yes | |
| rtrim(X) | Yes | |

View File

@@ -88,6 +88,7 @@ pub enum ScalarFunc {
Unhex,
ZeroBlob,
LastInsertRowid,
Replace,
}
impl Display for ScalarFunc {
@@ -132,6 +133,7 @@ impl Display for ScalarFunc {
ScalarFunc::Unhex => "unhex".to_string(),
ScalarFunc::ZeroBlob => "zeroblob".to_string(),
ScalarFunc::LastInsertRowid => "last_insert_rowid".to_string(),
ScalarFunc::Replace => "replace".to_string(),
};
write!(f, "{}", str)
}
@@ -206,6 +208,7 @@ impl Func {
"unicode" => Ok(Func::Scalar(ScalarFunc::Unicode)),
"quote" => Ok(Func::Scalar(ScalarFunc::Quote)),
"sqlite_version" => Ok(Func::Scalar(ScalarFunc::SqliteVersion)),
"replace" => Ok(Func::Scalar(ScalarFunc::Replace)),
#[cfg(feature = "json")]
"json" => Ok(Func::Json(JsonFunc::Json)),
"unixepoch" => Ok(Func::Scalar(ScalarFunc::UnixEpoch)),

View File

@@ -1550,6 +1550,57 @@ pub fn translate_expr(
});
Ok(target_register)
}
ScalarFunc::Replace => {
let args = if let Some(args) = args {
if !args.len() == 3 {
crate::bail_parse_error!(
"function {}() requires exactly 3 arguments",
srf.to_string()
)
}
args
} else {
crate::bail_parse_error!(
"function {}() requires exactly 3 arguments",
srf.to_string()
);
};
let str_reg = program.alloc_register();
let pattern_reg = program.alloc_register();
let replacement_reg = program.alloc_register();
translate_expr(
program,
referenced_tables,
&args[0],
str_reg,
precomputed_exprs_to_registers,
)?;
translate_expr(
program,
referenced_tables,
&args[1],
pattern_reg,
precomputed_exprs_to_registers,
)?;
translate_expr(
program,
referenced_tables,
&args[2],
replacement_reg,
precomputed_exprs_to_registers,
)?;
program.emit_insn(Insn::Function {
constant_mask: 0,
start_reg: str_reg,
dest: target_register,
func: func_ctx,
});
Ok(target_register)
}
}
}
}

View File

@@ -2483,6 +2483,13 @@ impl Program {
let version = execute_sqlite_version(version_integer);
state.registers[*dest] = OwnedValue::Text(Rc::new(version));
}
ScalarFunc::Replace => {
assert!(arg_count == 3);
let source = &state.registers[*start_reg];
let pattern = &state.registers[*start_reg + 1];
let replacement = &state.registers[*start_reg + 2];
state.registers[*dest] = exec_replace(source, pattern, replacement);
}
},
crate::function::Func::Agg(_) => {
unreachable!("Aggregate functions should not be handled here")
@@ -3422,6 +3429,37 @@ fn exec_cast(value: &OwnedValue, datatype: &str) -> OwnedValue {
}
}
fn exec_replace(source: &OwnedValue, pattern: &OwnedValue, replacement: &OwnedValue) -> OwnedValue {
// The replace(X,Y,Z) function returns a string formed by substituting string Z for every occurrence of
// string Y in string X. The BINARY collating sequence is used for comparisons. If Y is an empty string
// then return X unchanged. If Z is not initially a string, it is cast to a UTF-8 string prior to processing.
// If any of the arguments is NULL, the result is NULL.
if matches!(source, OwnedValue::Null)
|| matches!(pattern, OwnedValue::Null)
|| matches!(replacement, OwnedValue::Null)
{
return OwnedValue::Null;
}
let source = exec_cast(source, "TEXT");
let pattern = exec_cast(pattern, "TEXT");
let replacement = exec_cast(replacement, "TEXT");
// If any of the casts failed, panic as text casting is not expected to fail.
match (&source, &pattern, &replacement) {
(OwnedValue::Text(source), OwnedValue::Text(pattern), OwnedValue::Text(replacement)) => {
if pattern.is_empty() {
return OwnedValue::Text(source.clone());
}
let result = source.replace(pattern.as_str(), replacement);
OwnedValue::Text(Rc::new(result))
}
_ => unreachable!("text cast should never fail"),
}
}
enum Affinity {
Integer,
Text,
@@ -3549,7 +3587,10 @@ fn execute_sqlite_version(version_integer: i64) -> String {
#[cfg(test)]
mod tests {
use crate::types::{SeekKey, SeekOp};
use crate::{
types::{SeekKey, SeekOp},
vdbe::exec_replace,
};
use super::{
exec_abs, exec_char, exec_hex, exec_if, exec_instr, exec_length, exec_like, exec_lower,
@@ -4495,4 +4536,107 @@ mod tests {
let expected = "3.46.1";
assert_eq!(execute_sqlite_version(version_integer), expected);
}
#[test]
fn test_replace() {
let input_str = OwnedValue::Text(Rc::new(String::from("bob")));
let pattern_str = OwnedValue::Text(Rc::new(String::from("b")));
let replace_str = OwnedValue::Text(Rc::new(String::from("a")));
let expected_str = OwnedValue::Text(Rc::new(String::from("aoa")));
assert_eq!(
exec_replace(&input_str, &pattern_str, &replace_str),
expected_str
);
let input_str = OwnedValue::Text(Rc::new(String::from("bob")));
let pattern_str = OwnedValue::Text(Rc::new(String::from("b")));
let replace_str = OwnedValue::Text(Rc::new(String::from("")));
let expected_str = OwnedValue::Text(Rc::new(String::from("o")));
assert_eq!(
exec_replace(&input_str, &pattern_str, &replace_str),
expected_str
);
let input_str = OwnedValue::Text(Rc::new(String::from("bob")));
let pattern_str = OwnedValue::Text(Rc::new(String::from("b")));
let replace_str = OwnedValue::Text(Rc::new(String::from("abc")));
let expected_str = OwnedValue::Text(Rc::new(String::from("abcoabc")));
assert_eq!(
exec_replace(&input_str, &pattern_str, &replace_str),
expected_str
);
let input_str = OwnedValue::Text(Rc::new(String::from("bob")));
let pattern_str = OwnedValue::Text(Rc::new(String::from("a")));
let replace_str = OwnedValue::Text(Rc::new(String::from("b")));
let expected_str = OwnedValue::Text(Rc::new(String::from("bob")));
assert_eq!(
exec_replace(&input_str, &pattern_str, &replace_str),
expected_str
);
let input_str = OwnedValue::Text(Rc::new(String::from("bob")));
let pattern_str = OwnedValue::Text(Rc::new(String::from("")));
let replace_str = OwnedValue::Text(Rc::new(String::from("a")));
let expected_str = OwnedValue::Text(Rc::new(String::from("bob")));
assert_eq!(
exec_replace(&input_str, &pattern_str, &replace_str),
expected_str
);
let input_str = OwnedValue::Text(Rc::new(String::from("bob")));
let pattern_str = OwnedValue::Null;
let replace_str = OwnedValue::Text(Rc::new(String::from("a")));
let expected_str = OwnedValue::Null;
assert_eq!(
exec_replace(&input_str, &pattern_str, &replace_str),
expected_str
);
let input_str = OwnedValue::Text(Rc::new(String::from("bo5")));
let pattern_str = OwnedValue::Integer(5);
let replace_str = OwnedValue::Text(Rc::new(String::from("a")));
let expected_str = OwnedValue::Text(Rc::new(String::from("boa")));
assert_eq!(
exec_replace(&input_str, &pattern_str, &replace_str),
expected_str
);
let input_str = OwnedValue::Text(Rc::new(String::from("bo5.0")));
let pattern_str = OwnedValue::Float(5.0);
let replace_str = OwnedValue::Text(Rc::new(String::from("a")));
let expected_str = OwnedValue::Text(Rc::new(String::from("boa")));
assert_eq!(
exec_replace(&input_str, &pattern_str, &replace_str),
expected_str
);
let input_str = OwnedValue::Text(Rc::new(String::from("bo5")));
let pattern_str = OwnedValue::Float(5.0);
let replace_str = OwnedValue::Text(Rc::new(String::from("a")));
let expected_str = OwnedValue::Text(Rc::new(String::from("bo5")));
assert_eq!(
exec_replace(&input_str, &pattern_str, &replace_str),
expected_str
);
let input_str = OwnedValue::Text(Rc::new(String::from("bo5.0")));
let pattern_str = OwnedValue::Float(5.0);
let replace_str = OwnedValue::Float(6.0);
let expected_str = OwnedValue::Text(Rc::new(String::from("bo6.0")));
assert_eq!(
exec_replace(&input_str, &pattern_str, &replace_str),
expected_str
);
// todo: change this test to use (0.1 + 0.2) instead of 0.3 when decimals are implemented.
let input_str = OwnedValue::Text(Rc::new(String::from("tes3")));
let pattern_str = OwnedValue::Integer(3);
let replace_str = OwnedValue::Agg(Box::new(AggContext::Sum(OwnedValue::Float(0.3))));
let expected_str = OwnedValue::Text(Rc::new(String::from("tes0.3")));
assert_eq!(
exec_replace(&input_str, &pattern_str, &replace_str),
expected_str
);
}
}

View File

@@ -167,6 +167,18 @@ do_execsql_test lower-null {
select lower(null)
} {}
do_execsql_test replace {
select replace('test', 'test', 'example')
} {example}
do_execsql_test replace-number {
select replace('tes3', 3, 0.3)
} {tes0.3}
do_execsql_test replace-null {
select replace('test', null, 'example')
} {}
do_execsql_test hex {
select hex('limbo')
} {6C696D626F}