diff --git a/COMPAT.md b/COMPAT.md index b6db9e455..4c0210257 100644 --- a/COMPAT.md +++ b/COMPAT.md @@ -133,7 +133,7 @@ Feature support of [sqlite expr syntax](https://www.sqlite.org/lang_expr.html). | quote(X) | Yes | | | random() | Yes | | | randomblob(N) | Yes | | -| replace(X,Y,Z) | No | | +| replace(X,Y,Z) | Yes | | | round(X) | Yes | | | round(X,Y) | Yes | | | rtrim(X) | Yes | | diff --git a/core/function.rs b/core/function.rs index 0659da73f..9e2dca0f3 100644 --- a/core/function.rs +++ b/core/function.rs @@ -88,6 +88,7 @@ pub enum ScalarFunc { Unhex, ZeroBlob, LastInsertRowid, + Replace, } impl Display for ScalarFunc { @@ -132,6 +133,7 @@ impl Display for ScalarFunc { ScalarFunc::Unhex => "unhex".to_string(), ScalarFunc::ZeroBlob => "zeroblob".to_string(), ScalarFunc::LastInsertRowid => "last_insert_rowid".to_string(), + ScalarFunc::Replace => "replace".to_string(), }; write!(f, "{}", str) } @@ -206,6 +208,7 @@ impl Func { "unicode" => Ok(Func::Scalar(ScalarFunc::Unicode)), "quote" => Ok(Func::Scalar(ScalarFunc::Quote)), "sqlite_version" => Ok(Func::Scalar(ScalarFunc::SqliteVersion)), + "replace" => Ok(Func::Scalar(ScalarFunc::Replace)), #[cfg(feature = "json")] "json" => Ok(Func::Json(JsonFunc::Json)), "unixepoch" => Ok(Func::Scalar(ScalarFunc::UnixEpoch)), diff --git a/core/translate/expr.rs b/core/translate/expr.rs index 8f4de5c13..c327f547d 100644 --- a/core/translate/expr.rs +++ b/core/translate/expr.rs @@ -1550,6 +1550,57 @@ pub fn translate_expr( }); Ok(target_register) } + ScalarFunc::Replace => { + let args = if let Some(args) = args { + if !args.len() == 3 { + crate::bail_parse_error!( + "function {}() requires exactly 3 arguments", + srf.to_string() + ) + } + args + } else { + crate::bail_parse_error!( + "function {}() requires exactly 3 arguments", + srf.to_string() + ); + }; + + let str_reg = program.alloc_register(); + let pattern_reg = program.alloc_register(); + let replacement_reg = program.alloc_register(); + + translate_expr( + program, + referenced_tables, + &args[0], + str_reg, + precomputed_exprs_to_registers, + )?; + translate_expr( + program, + referenced_tables, + &args[1], + pattern_reg, + precomputed_exprs_to_registers, + )?; + + translate_expr( + program, + referenced_tables, + &args[2], + replacement_reg, + precomputed_exprs_to_registers, + )?; + + program.emit_insn(Insn::Function { + constant_mask: 0, + start_reg: str_reg, + dest: target_register, + func: func_ctx, + }); + Ok(target_register) + } } } } diff --git a/core/vdbe/mod.rs b/core/vdbe/mod.rs index 624284694..e576f1cb7 100644 --- a/core/vdbe/mod.rs +++ b/core/vdbe/mod.rs @@ -2483,6 +2483,13 @@ impl Program { let version = execute_sqlite_version(version_integer); state.registers[*dest] = OwnedValue::Text(Rc::new(version)); } + ScalarFunc::Replace => { + assert!(arg_count == 3); + let source = &state.registers[*start_reg]; + let pattern = &state.registers[*start_reg + 1]; + let replacement = &state.registers[*start_reg + 2]; + state.registers[*dest] = exec_replace(source, pattern, replacement); + } }, crate::function::Func::Agg(_) => { unreachable!("Aggregate functions should not be handled here") @@ -3422,6 +3429,37 @@ fn exec_cast(value: &OwnedValue, datatype: &str) -> OwnedValue { } } +fn exec_replace(source: &OwnedValue, pattern: &OwnedValue, replacement: &OwnedValue) -> OwnedValue { + // The replace(X,Y,Z) function returns a string formed by substituting string Z for every occurrence of + // string Y in string X. The BINARY collating sequence is used for comparisons. If Y is an empty string + // then return X unchanged. If Z is not initially a string, it is cast to a UTF-8 string prior to processing. + + // If any of the arguments is NULL, the result is NULL. + if matches!(source, OwnedValue::Null) + || matches!(pattern, OwnedValue::Null) + || matches!(replacement, OwnedValue::Null) + { + return OwnedValue::Null; + } + + let source = exec_cast(source, "TEXT"); + let pattern = exec_cast(pattern, "TEXT"); + let replacement = exec_cast(replacement, "TEXT"); + + // If any of the casts failed, panic as text casting is not expected to fail. + match (&source, &pattern, &replacement) { + (OwnedValue::Text(source), OwnedValue::Text(pattern), OwnedValue::Text(replacement)) => { + if pattern.is_empty() { + return OwnedValue::Text(source.clone()); + } + + let result = source.replace(pattern.as_str(), replacement); + OwnedValue::Text(Rc::new(result)) + } + _ => unreachable!("text cast should never fail"), + } +} + enum Affinity { Integer, Text, @@ -3549,7 +3587,10 @@ fn execute_sqlite_version(version_integer: i64) -> String { #[cfg(test)] mod tests { - use crate::types::{SeekKey, SeekOp}; + use crate::{ + types::{SeekKey, SeekOp}, + vdbe::exec_replace, + }; use super::{ exec_abs, exec_char, exec_hex, exec_if, exec_instr, exec_length, exec_like, exec_lower, @@ -4495,4 +4536,107 @@ mod tests { let expected = "3.46.1"; assert_eq!(execute_sqlite_version(version_integer), expected); } + + #[test] + fn test_replace() { + let input_str = OwnedValue::Text(Rc::new(String::from("bob"))); + let pattern_str = OwnedValue::Text(Rc::new(String::from("b"))); + let replace_str = OwnedValue::Text(Rc::new(String::from("a"))); + let expected_str = OwnedValue::Text(Rc::new(String::from("aoa"))); + assert_eq!( + exec_replace(&input_str, &pattern_str, &replace_str), + expected_str + ); + + let input_str = OwnedValue::Text(Rc::new(String::from("bob"))); + let pattern_str = OwnedValue::Text(Rc::new(String::from("b"))); + let replace_str = OwnedValue::Text(Rc::new(String::from(""))); + let expected_str = OwnedValue::Text(Rc::new(String::from("o"))); + assert_eq!( + exec_replace(&input_str, &pattern_str, &replace_str), + expected_str + ); + + let input_str = OwnedValue::Text(Rc::new(String::from("bob"))); + let pattern_str = OwnedValue::Text(Rc::new(String::from("b"))); + let replace_str = OwnedValue::Text(Rc::new(String::from("abc"))); + let expected_str = OwnedValue::Text(Rc::new(String::from("abcoabc"))); + assert_eq!( + exec_replace(&input_str, &pattern_str, &replace_str), + expected_str + ); + + let input_str = OwnedValue::Text(Rc::new(String::from("bob"))); + let pattern_str = OwnedValue::Text(Rc::new(String::from("a"))); + let replace_str = OwnedValue::Text(Rc::new(String::from("b"))); + let expected_str = OwnedValue::Text(Rc::new(String::from("bob"))); + assert_eq!( + exec_replace(&input_str, &pattern_str, &replace_str), + expected_str + ); + + let input_str = OwnedValue::Text(Rc::new(String::from("bob"))); + let pattern_str = OwnedValue::Text(Rc::new(String::from(""))); + let replace_str = OwnedValue::Text(Rc::new(String::from("a"))); + let expected_str = OwnedValue::Text(Rc::new(String::from("bob"))); + assert_eq!( + exec_replace(&input_str, &pattern_str, &replace_str), + expected_str + ); + + let input_str = OwnedValue::Text(Rc::new(String::from("bob"))); + let pattern_str = OwnedValue::Null; + let replace_str = OwnedValue::Text(Rc::new(String::from("a"))); + let expected_str = OwnedValue::Null; + assert_eq!( + exec_replace(&input_str, &pattern_str, &replace_str), + expected_str + ); + + let input_str = OwnedValue::Text(Rc::new(String::from("bo5"))); + let pattern_str = OwnedValue::Integer(5); + let replace_str = OwnedValue::Text(Rc::new(String::from("a"))); + let expected_str = OwnedValue::Text(Rc::new(String::from("boa"))); + assert_eq!( + exec_replace(&input_str, &pattern_str, &replace_str), + expected_str + ); + + let input_str = OwnedValue::Text(Rc::new(String::from("bo5.0"))); + let pattern_str = OwnedValue::Float(5.0); + let replace_str = OwnedValue::Text(Rc::new(String::from("a"))); + let expected_str = OwnedValue::Text(Rc::new(String::from("boa"))); + assert_eq!( + exec_replace(&input_str, &pattern_str, &replace_str), + expected_str + ); + + let input_str = OwnedValue::Text(Rc::new(String::from("bo5"))); + let pattern_str = OwnedValue::Float(5.0); + let replace_str = OwnedValue::Text(Rc::new(String::from("a"))); + let expected_str = OwnedValue::Text(Rc::new(String::from("bo5"))); + assert_eq!( + exec_replace(&input_str, &pattern_str, &replace_str), + expected_str + ); + + let input_str = OwnedValue::Text(Rc::new(String::from("bo5.0"))); + let pattern_str = OwnedValue::Float(5.0); + let replace_str = OwnedValue::Float(6.0); + let expected_str = OwnedValue::Text(Rc::new(String::from("bo6.0"))); + assert_eq!( + exec_replace(&input_str, &pattern_str, &replace_str), + expected_str + ); + + // todo: change this test to use (0.1 + 0.2) instead of 0.3 when decimals are implemented. + let input_str = OwnedValue::Text(Rc::new(String::from("tes3"))); + let pattern_str = OwnedValue::Integer(3); + let replace_str = OwnedValue::Agg(Box::new(AggContext::Sum(OwnedValue::Float(0.3)))); + let expected_str = OwnedValue::Text(Rc::new(String::from("tes0.3"))); + assert_eq!( + exec_replace(&input_str, &pattern_str, &replace_str), + expected_str + ); + } } diff --git a/testing/scalar-functions.test b/testing/scalar-functions.test index f7af9843a..b48f2eeeb 100755 --- a/testing/scalar-functions.test +++ b/testing/scalar-functions.test @@ -167,6 +167,18 @@ do_execsql_test lower-null { select lower(null) } {} +do_execsql_test replace { + select replace('test', 'test', 'example') +} {example} + +do_execsql_test replace-number { + select replace('tes3', 3, 0.3) +} {tes0.3} + +do_execsql_test replace-null { + select replace('test', null, 'example') +} {} + do_execsql_test hex { select hex('limbo') } {6C696D626F}